|
|
@ -1,19 +1,19 @@
|
|
|
|
from __future__ import annotations
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
import asyncio
|
|
|
|
import json
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
import traceback
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
from api.model.toolkit_ui.conversation import ConversationHelper
|
|
|
|
from local import noawait
|
|
|
|
from local import noawait
|
|
|
|
from typing import Optional, Callable, TypedDict
|
|
|
|
from typing import Optional, Callable, TypedDict
|
|
|
|
from aiohttp import web
|
|
|
|
from aiohttp import web
|
|
|
|
from sqlalchemy import select
|
|
|
|
from sqlalchemy import select
|
|
|
|
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
|
|
|
|
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
|
|
|
|
from noawait import NoAwaitPool
|
|
|
|
from noawait import NoAwaitPool
|
|
|
|
from service.chat_complete import ChatCompleteService
|
|
|
|
from service.chat_complete import ChatCompleteService, ChatCompleteServiceResponse
|
|
|
|
from service.database import DatabaseService
|
|
|
|
from service.database import DatabaseService
|
|
|
|
from service.embedding_search import EmbeddingSearchArgs
|
|
|
|
from service.embedding_search import EmbeddingSearchArgs
|
|
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
|
|
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
|
|
|
|
from service.tiktoken import TikTokenService
|
|
|
|
from service.tiktoken import TikTokenService
|
|
|
|
import utils.web
|
|
|
|
import utils.web
|
|
|
|
|
|
|
|
|
|
|
@ -23,6 +23,7 @@ class ChatCompleteTask:
|
|
|
|
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
|
|
|
|
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
|
|
|
|
self.task_id = utils.web.generate_uuid()
|
|
|
|
self.task_id = utils.web.generate_uuid()
|
|
|
|
self.on_message: list[Callable] = []
|
|
|
|
self.on_message: list[Callable] = []
|
|
|
|
|
|
|
|
self.on_finished: list[Callable] = []
|
|
|
|
self.on_error: list[Callable] = []
|
|
|
|
self.on_error: list[Callable] = []
|
|
|
|
self.chunks: list[str] = []
|
|
|
|
self.chunks: list[str] = []
|
|
|
|
|
|
|
|
|
|
|
@ -37,7 +38,12 @@ class ChatCompleteTask:
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
self.point_cost: int = 0
|
|
|
|
self.point_cost: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
async def init(self, question: str, conversation_id: Optional[str] = None,
|
|
|
|
self.is_finished = False
|
|
|
|
|
|
|
|
self.finished_time: Optional[float] = None
|
|
|
|
|
|
|
|
self.result: Optional[ChatCompleteServiceResponse] = None
|
|
|
|
|
|
|
|
self.error: Optional[Exception] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
|
|
|
|
embedding_search: Optional[EmbeddingSearchArgs] = None):
|
|
|
|
embedding_search: Optional[EmbeddingSearchArgs] = None):
|
|
|
|
self.tiktoken = await TikTokenService.create()
|
|
|
|
self.tiktoken = await TikTokenService.create()
|
|
|
|
|
|
|
|
|
|
|
@ -54,34 +60,58 @@ class ChatCompleteTask:
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
self.point_cost: int = 0
|
|
|
|
self.point_cost: int = 0
|
|
|
|
if not self.is_system:
|
|
|
|
if not self.is_system:
|
|
|
|
usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete", question_tokens, extract_limit)
|
|
|
|
usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete",
|
|
|
|
self.transatcion_id = usage_res.get("transaction_id")
|
|
|
|
question_tokens, extract_limit)
|
|
|
|
self.point_cost = usage_res.get("point_cost")
|
|
|
|
self.transatcion_id = usage_res["transaction_id"]
|
|
|
|
|
|
|
|
self.point_cost = usage_res["point_cost"]
|
|
|
|
chat_res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
|
|
|
|
|
|
|
user_id=self.user_id, embedding_search=embedding_search)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return chat_res
|
|
|
|
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
|
|
|
|
|
|
|
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
await self._exit()
|
|
|
|
await self._exit()
|
|
|
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
|
|
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
|
|
|
|
|
|
|
|
|
|
|
async def _on_message(self, delta_message: str):
|
|
|
|
async def _on_message(self, delta_message: str):
|
|
|
|
|
|
|
|
self.chunks.append(delta_message)
|
|
|
|
|
|
|
|
|
|
|
|
for callback in self.on_message:
|
|
|
|
for callback in self.on_message:
|
|
|
|
await callback(delta_message)
|
|
|
|
try:
|
|
|
|
|
|
|
|
await callback(delta_message)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
|
|
|
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _on_finished(self):
|
|
|
|
|
|
|
|
for callback in self.on_finished:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
await callback(self.result)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
|
|
|
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
async def _on_error(self, err: Exception):
|
|
|
|
async def _on_error(self, err: Exception):
|
|
|
|
|
|
|
|
self.error = err
|
|
|
|
for callback in self.on_error:
|
|
|
|
for callback in self.on_error:
|
|
|
|
await callback(err)
|
|
|
|
try:
|
|
|
|
|
|
|
|
await callback(err)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
|
|
|
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
async def run(self):
|
|
|
|
async def run(self):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
chat_res = self.chat_complete.finish_chat_complete(self._on_message)
|
|
|
|
chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
|
|
|
|
|
|
|
|
|
|
|
|
await self.chat_complete.set_latest_point_cost(self.point_cost)
|
|
|
|
await self.chat_complete.set_latest_point_cost(self.point_cost)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.result = chat_res
|
|
|
|
|
|
|
|
|
|
|
|
if self.transatcion_id:
|
|
|
|
if self.transatcion_id:
|
|
|
|
result = await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
|
|
|
|
await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await self._on_finished()
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
err_msg = f"Error while processing chat complete request: {e}"
|
|
|
|
err_msg = f"Error while processing chat complete request: {e}"
|
|
|
|
|
|
|
|
|
|
|
@ -89,7 +119,7 @@ class ChatCompleteTask:
|
|
|
|
traceback.print_exc()
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
if self.transatcion_id:
|
|
|
|
if self.transatcion_id:
|
|
|
|
result = await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
|
|
|
|
await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
|
|
|
|
|
|
|
|
|
|
|
|
await self._on_error(e)
|
|
|
|
await self._on_error(e)
|
|
|
|
finally:
|
|
|
|
finally:
|
|
|
@ -98,10 +128,19 @@ class ChatCompleteTask:
|
|
|
|
async def _exit(self):
|
|
|
|
async def _exit(self):
|
|
|
|
await self.chat_complete_service.__aexit__(None, None, None)
|
|
|
|
await self.chat_complete_service.__aexit__(None, None, None)
|
|
|
|
del chat_complete_tasks[self.task_id]
|
|
|
|
del chat_complete_tasks[self.task_id]
|
|
|
|
|
|
|
|
self.is_finished = True
|
|
|
|
|
|
|
|
self.finished_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
@noawait.wrap
|
|
|
|
TASK_EXPIRE_TIME = 60 * 10
|
|
|
|
async def start(self):
|
|
|
|
|
|
|
|
await self.run()
|
|
|
|
async def chat_complete_task_gc():
|
|
|
|
|
|
|
|
now = time.time()
|
|
|
|
|
|
|
|
for task_id in chat_complete_tasks.keys():
|
|
|
|
|
|
|
|
task = chat_complete_tasks[task_id]
|
|
|
|
|
|
|
|
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
|
|
|
|
|
|
|
|
del chat_complete_tasks[task_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
noawait.add_timer(chat_complete_task_gc, 60)
|
|
|
|
|
|
|
|
|
|
|
|
class ChatComplete:
|
|
|
|
class ChatComplete:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
@ -112,7 +151,7 @@ class ChatComplete:
|
|
|
|
"required": False,
|
|
|
|
"required": False,
|
|
|
|
"type": int
|
|
|
|
"type": int
|
|
|
|
},
|
|
|
|
},
|
|
|
|
"conversation_id": {
|
|
|
|
"id": {
|
|
|
|
"required": True,
|
|
|
|
"required": True,
|
|
|
|
"type": int
|
|
|
|
"type": int
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -123,40 +162,31 @@ class ChatComplete:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
user_id = params.get("user_id")
|
|
|
|
user_id = params.get("user_id")
|
|
|
|
|
|
|
|
|
|
|
|
conversation_id = params.get("conversation_id")
|
|
|
|
conversation_id = params.get("id")
|
|
|
|
|
|
|
|
|
|
|
|
db = await DatabaseService.create(request.app)
|
|
|
|
db = await DatabaseService.create(request.app)
|
|
|
|
|
|
|
|
|
|
|
|
async with db.create_session() as session:
|
|
|
|
async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
|
|
|
|
stmt = select(ConversationModel).where(
|
|
|
|
conversation_info = await conversation_helper.find_by_id(conversation_id)
|
|
|
|
ConversationModel.id == conversation_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conversation_data = await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if conversation_data is None:
|
|
|
|
if conversation_info is None:
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
"code": "conversation-not-found",
|
|
|
|
"code": "conversation-not-found",
|
|
|
|
"message": "Conversation not found."
|
|
|
|
"message": "Conversation not found."
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
if conversation_data.user_id != user_id:
|
|
|
|
if conversation_info.user_id != user_id:
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
"code": "permission-denied",
|
|
|
|
"code": "permission-denied",
|
|
|
|
"message": "Permission denied."
|
|
|
|
"message": "Permission denied."
|
|
|
|
}, http_status=403, request=request)
|
|
|
|
}, http_status=403, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
stmt = select(ConversationChunkModel).with_only_columns([ConversationChunkModel.id, ConversationChunkModel.updated_at]) \
|
|
|
|
conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id)
|
|
|
|
.where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conversation_chunk_result = await session.scalars(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conversation_chunk_list = []
|
|
|
|
conversation_chunk_list = []
|
|
|
|
|
|
|
|
|
|
|
|
for result in conversation_chunk_result:
|
|
|
|
for result in conversation_chunk_result:
|
|
|
|
conversation_chunk_list.append({
|
|
|
|
conversation_chunk_list.append(result)
|
|
|
|
"id": result.id,
|
|
|
|
|
|
|
|
"updated_at": result.updated_at
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return await utils.web.api_response(1, conversation_chunk_list, request=request)
|
|
|
|
return await utils.web.api_response(1, conversation_chunk_list, request=request)
|
|
|
|
|
|
|
|
|
|
|
@ -181,26 +211,32 @@ class ChatComplete:
|
|
|
|
|
|
|
|
|
|
|
|
chunk_id = params.get("chunk_id")
|
|
|
|
chunk_id = params.get("chunk_id")
|
|
|
|
|
|
|
|
|
|
|
|
dbs = await DatabaseService.create(request.app)
|
|
|
|
db = await DatabaseService.create(request.app)
|
|
|
|
async with dbs.create_session() as session:
|
|
|
|
async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
|
|
|
|
stmt = select(ConversationChunkModel).where(
|
|
|
|
chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
|
|
|
|
ConversationChunkModel.id == chunk_id)
|
|
|
|
if chunk_info is None:
|
|
|
|
|
|
|
|
|
|
|
|
conversation_data = await session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if conversation_data is None:
|
|
|
|
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
"code": "conversation-chunk-not-found",
|
|
|
|
"code": "conversation-chunk-not-found",
|
|
|
|
"message": "Conversation chunk not found."
|
|
|
|
"message": "Conversation chunk not found."
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
if conversation_data.conversation.user_id != user_id:
|
|
|
|
conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if conversation_info is not None and conversation_info.user_id != user_id:
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
"code": "permission-denied",
|
|
|
|
"code": "permission-denied",
|
|
|
|
"message": "Permission denied."
|
|
|
|
"message": "Permission denied."
|
|
|
|
}, http_status=403, request=request)
|
|
|
|
}, http_status=403, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
return await utils.web.api_response(1, conversation_data.__dict__, request=request)
|
|
|
|
chunk_dict = {
|
|
|
|
|
|
|
|
"id": chunk_info.id,
|
|
|
|
|
|
|
|
"conversation_id": chunk_info.conversation_id,
|
|
|
|
|
|
|
|
"message_data": chunk_info.message_data,
|
|
|
|
|
|
|
|
"tokens": chunk_info.tokens,
|
|
|
|
|
|
|
|
"updated_at": chunk_info.updated_at,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return await utils.web.api_response(1, chunk_dict, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@utils.web.token_auth
|
|
|
|
@utils.web.token_auth
|
|
|
@ -219,6 +255,47 @@ class ChatComplete:
|
|
|
|
|
|
|
|
|
|
|
|
return await utils.web.api_response(1, {"tokens": tokens}, request=request)
|
|
|
|
return await utils.web.api_response(1, {"tokens": tokens}, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
@utils.web.token_auth
|
|
|
|
|
|
|
|
async def get_point_cost(request: web.Request):
|
|
|
|
|
|
|
|
params = await utils.web.get_param(request, {
|
|
|
|
|
|
|
|
"question": {
|
|
|
|
|
|
|
|
"type": str,
|
|
|
|
|
|
|
|
"required": True,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
"extract_limit": {
|
|
|
|
|
|
|
|
"type": int,
|
|
|
|
|
|
|
|
"required": False,
|
|
|
|
|
|
|
|
"default": 10,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
user_id = request.get("user")
|
|
|
|
|
|
|
|
caller = request.get("caller")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
question = params.get("question")
|
|
|
|
|
|
|
|
extract_limit = params.get("extract_limit")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tiktoken = await TikTokenService.create()
|
|
|
|
|
|
|
|
mwapi = MediaWikiApi.create()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokens = await tiktoken.get_tokens(question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
res = await mwapi.chat_complete_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
|
|
|
|
|
|
|
|
return await utils.web.api_response(1, {
|
|
|
|
|
|
|
|
"point_cost": res["point_cost"],
|
|
|
|
|
|
|
|
"tokens": tokens,
|
|
|
|
|
|
|
|
}, request=request)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
err_msg = f"Error while get chat complete point cost: {e}"
|
|
|
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
|
|
|
|
"code": "chat-complete-error",
|
|
|
|
|
|
|
|
"message": err_msg
|
|
|
|
|
|
|
|
}, http_status=500, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@utils.web.token_auth
|
|
|
|
@utils.web.token_auth
|
|
|
|
async def start_chat_complete(request: web.Request):
|
|
|
|
async def start_chat_complete(request: web.Request):
|
|
|
@ -245,6 +322,10 @@ class ChatComplete:
|
|
|
|
"required": False,
|
|
|
|
"required": False,
|
|
|
|
"default": False,
|
|
|
|
"default": False,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
|
|
|
|
"edit_message_id": {
|
|
|
|
|
|
|
|
"type": str,
|
|
|
|
|
|
|
|
"required": False,
|
|
|
|
|
|
|
|
},
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
user_id = request.get("user")
|
|
|
|
user_id = request.get("user")
|
|
|
@ -257,20 +338,25 @@ class ChatComplete:
|
|
|
|
extract_limit = params.get("extract_limit")
|
|
|
|
extract_limit = params.get("extract_limit")
|
|
|
|
in_collection = params.get("in_collection")
|
|
|
|
in_collection = params.get("in_collection")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
edit_message_id = params.get("edit_message_id")
|
|
|
|
|
|
|
|
|
|
|
|
dbs = await DatabaseService.create(request.app)
|
|
|
|
dbs = await DatabaseService.create(request.app)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
|
|
|
|
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
|
|
|
|
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
|
|
|
|
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id,
|
|
|
|
|
|
|
|
embedding_search={
|
|
|
|
"limit": extract_limit,
|
|
|
|
"limit": extract_limit,
|
|
|
|
"in_collection": in_collection,
|
|
|
|
"in_collection": in_collection,
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
|
|
|
|
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
|
|
|
|
|
|
|
|
|
|
|
|
chat_complete_task.start()
|
|
|
|
noawait.add_task(chat_complete_task.run())
|
|
|
|
|
|
|
|
|
|
|
|
return utils.web.api_response(1, data={
|
|
|
|
return await utils.web.api_response(1, data={
|
|
|
|
|
|
|
|
"conversation_id": init_res["conversation_id"],
|
|
|
|
|
|
|
|
"chunk_id": init_res["chunk_id"],
|
|
|
|
"question_tokens": init_res["question_tokens"],
|
|
|
|
"question_tokens": init_res["question_tokens"],
|
|
|
|
"extract_doc": init_res["extract_doc"],
|
|
|
|
"extract_doc": init_res["extract_doc"],
|
|
|
|
"task_id": chat_complete_task.task_id,
|
|
|
|
"task_id": chat_complete_task.task_id,
|
|
|
@ -282,6 +368,12 @@ class ChatComplete:
|
|
|
|
"title": page_title,
|
|
|
|
"title": page_title,
|
|
|
|
"message": error_msg
|
|
|
|
"message": error_msg
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
}, http_status=404, request=request)
|
|
|
|
|
|
|
|
except MediaWikiUserNoEnoughPointsException as e:
|
|
|
|
|
|
|
|
error_msg = "Does not have enough points." % user_id
|
|
|
|
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
|
|
|
|
"code": "no-enough-points",
|
|
|
|
|
|
|
|
"message": error_msg
|
|
|
|
|
|
|
|
}, http_status=403, request=request)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
err_msg = f"Error while processing chat complete request: {e}"
|
|
|
|
err_msg = f"Error while processing chat complete request: {e}"
|
|
|
|
traceback.print_exc()
|
|
|
|
traceback.print_exc()
|
|
|
@ -294,4 +386,123 @@ class ChatComplete:
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@utils.web.token_auth
|
|
|
|
@utils.web.token_auth
|
|
|
|
async def chat_complete_stream(request: web.Request):
|
|
|
|
async def chat_complete_stream(request: web.Request):
|
|
|
|
pass
|
|
|
|
if not utils.web.is_websocket(request):
|
|
|
|
|
|
|
|
return await utils.web.api_response(-1, error={
|
|
|
|
|
|
|
|
"code": "websocket-required",
|
|
|
|
|
|
|
|
"message": "This API only accept websocket connection."
|
|
|
|
|
|
|
|
}, http_status=400, request=request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params = await utils.web.get_param(request, {
|
|
|
|
|
|
|
|
"task_id": {
|
|
|
|
|
|
|
|
"type": str,
|
|
|
|
|
|
|
|
"required": True,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ws = web.WebSocketResponse()
|
|
|
|
|
|
|
|
await ws.prepare(request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
task_id = params.get("task_id")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
task = chat_complete_tasks.get(task_id)
|
|
|
|
|
|
|
|
if task is None:
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'error',
|
|
|
|
|
|
|
|
'status': -1,
|
|
|
|
|
|
|
|
'message': "Task not found.",
|
|
|
|
|
|
|
|
'error': {
|
|
|
|
|
|
|
|
'code': "task-not-found",
|
|
|
|
|
|
|
|
'info': "Task not found.",
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if request.get("caller") == "user":
|
|
|
|
|
|
|
|
user_id = request.get("user")
|
|
|
|
|
|
|
|
if task.user_id != user_id:
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'error',
|
|
|
|
|
|
|
|
'status': -1,
|
|
|
|
|
|
|
|
'message': "Permission denied.",
|
|
|
|
|
|
|
|
'error': {
|
|
|
|
|
|
|
|
'code': "permission-denied",
|
|
|
|
|
|
|
|
'info': "Permission denied.",
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if task.is_finished:
|
|
|
|
|
|
|
|
if task.error is not None:
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'error',
|
|
|
|
|
|
|
|
'status': -1,
|
|
|
|
|
|
|
|
'message': str(task.error),
|
|
|
|
|
|
|
|
'error': {
|
|
|
|
|
|
|
|
'code': "internal-error",
|
|
|
|
|
|
|
|
'info': str(task.error),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
|
|
|
elif task.result is not None:
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'connected',
|
|
|
|
|
|
|
|
'status': 1,
|
|
|
|
|
|
|
|
'outputed_message': "".join(task.chunks),
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'finished',
|
|
|
|
|
|
|
|
'status': 1,
|
|
|
|
|
|
|
|
'result': task.result
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
async def on_message(delta_message: str):
|
|
|
|
|
|
|
|
await ws.send_str("+" + delta_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def on_finished(result: ChatCompleteServiceResponse):
|
|
|
|
|
|
|
|
ignored_keys = ["message"]
|
|
|
|
|
|
|
|
response_result = {
|
|
|
|
|
|
|
|
"point_cost": task.point_cost,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for k, v in result.items():
|
|
|
|
|
|
|
|
if k not in ignored_keys:
|
|
|
|
|
|
|
|
response_result[k] = v
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'finished',
|
|
|
|
|
|
|
|
'status': 1,
|
|
|
|
|
|
|
|
'result': response_result
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def on_error(err: Exception):
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'error',
|
|
|
|
|
|
|
|
'status': -1,
|
|
|
|
|
|
|
|
'message': str(err),
|
|
|
|
|
|
|
|
'error': {
|
|
|
|
|
|
|
|
'code': "internal-error",
|
|
|
|
|
|
|
|
'info': str(err),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
task.on_message.append(on_message)
|
|
|
|
|
|
|
|
task.on_finished.append(on_finished)
|
|
|
|
|
|
|
|
task.on_error.append(on_error)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Send received message
|
|
|
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
|
|
|
'event': 'connected',
|
|
|
|
|
|
|
|
'status': 1,
|
|
|
|
|
|
|
|
'outputed_message': "".join(task.chunks),
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
|
|
|
if ws.closed:
|
|
|
|
|
|
|
|
task.on_message.remove(on_message)
|
|
|
|
|
|
|
|
task.on_finished.remove(on_finished)
|
|
|
|
|
|
|
|
task.on_error.remove(on_error)
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
await asyncio.sleep(0.1)
|
|
|
|