|
|
|
@ -4,7 +4,11 @@ import time
|
|
|
|
|
import traceback
|
|
|
|
|
from local import noawait
|
|
|
|
|
from typing import Optional, Callable, Union
|
|
|
|
|
from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
|
|
|
|
|
from service.chat_complete import (
|
|
|
|
|
ChatCompleteService,
|
|
|
|
|
ChatCompleteServicePrepareResponse,
|
|
|
|
|
ChatCompleteServiceResponse,
|
|
|
|
|
)
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
from service.embedding_search import EmbeddingSearchArgs
|
|
|
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
|
|
|
|
@ -13,12 +17,15 @@ import utils.web
|
|
|
|
|
|
|
|
|
|
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatCompleteTask:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
|
|
|
|
|
return chat_complete_tasks.get(task_id)
|
|
|
|
|
|
|
|
|
|
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
|
|
|
|
|
):
|
|
|
|
|
self.task_id = utils.web.generate_uuid()
|
|
|
|
|
self.on_message: list[Callable] = []
|
|
|
|
|
self.on_finished: list[Callable] = []
|
|
|
|
@ -41,15 +48,21 @@ class ChatCompleteTask:
|
|
|
|
|
self.result: Optional[ChatCompleteServiceResponse] = None
|
|
|
|
|
self.error: Optional[Exception] = None
|
|
|
|
|
|
|
|
|
|
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
|
|
|
|
|
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
|
|
|
|
|
async def init(
|
|
|
|
|
self,
|
|
|
|
|
question: str,
|
|
|
|
|
conversation_id: Optional[str] = None,
|
|
|
|
|
edit_message_id: Optional[str] = None,
|
|
|
|
|
bot_id: Optional[str] = None,
|
|
|
|
|
embedding_search: Optional[EmbeddingSearchArgs] = None,
|
|
|
|
|
) -> ChatCompleteServicePrepareResponse:
|
|
|
|
|
self.tiktoken = await TikTokenService.create()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.mwapi = MediaWikiApi.create()
|
|
|
|
|
|
|
|
|
|
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
|
|
|
|
|
self.chat_complete = await self.chat_complete_service.__aenter__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if await self.chat_complete.page_index_exists():
|
|
|
|
|
question_tokens = await self.tiktoken.get_tokens(question)
|
|
|
|
@ -59,29 +72,41 @@ class ChatCompleteTask:
|
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
|
self.point_cost: int = 0
|
|
|
|
|
if not self.is_system:
|
|
|
|
|
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
|
|
|
|
|
question_tokens, extract_limit)
|
|
|
|
|
usage_res = await self.mwapi.ai_toolbox_start_transaction(
|
|
|
|
|
self.user_id, "chatcomplete", question_tokens, extract_limit
|
|
|
|
|
)
|
|
|
|
|
self.transatcion_id = usage_res["transaction_id"]
|
|
|
|
|
self.point_cost = usage_res["point_cost"]
|
|
|
|
|
|
|
|
|
|
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
|
|
|
|
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
|
|
|
|
|
|
|
|
|
res = await self.chat_complete.prepare_chat_complete(
|
|
|
|
|
question,
|
|
|
|
|
conversation_id=conversation_id,
|
|
|
|
|
user_id=self.user_id,
|
|
|
|
|
edit_message_id=edit_message_id,
|
|
|
|
|
bot_id=bot_id,
|
|
|
|
|
embedding_search=embedding_search,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
else:
|
|
|
|
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
|
|
|
|
raise MediaWikiPageNotFoundException(
|
|
|
|
|
"Page %s not found." % self.page_title
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await self.end()
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
async def _on_message(self, delta_message: str):
|
|
|
|
|
self.chunks.append(delta_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for callback in self.on_message:
|
|
|
|
|
try:
|
|
|
|
|
await callback(delta_message)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
|
|
|
|
|
print(
|
|
|
|
|
"Error while processing on_message callback: %s" % e,
|
|
|
|
|
file=sys.stderr,
|
|
|
|
|
)
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
async def _on_finished(self):
|
|
|
|
@ -89,7 +114,10 @@ class ChatCompleteTask:
|
|
|
|
|
try:
|
|
|
|
|
await callback(self.result)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
|
|
|
|
|
print(
|
|
|
|
|
"Error while processing on_finished callback: %s" % e,
|
|
|
|
|
file=sys.stderr,
|
|
|
|
|
)
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
async def _on_error(self, err: Exception):
|
|
|
|
@ -98,7 +126,9 @@ class ChatCompleteTask:
|
|
|
|
|
try:
|
|
|
|
|
await callback(err)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
|
|
|
|
|
print(
|
|
|
|
|
"Error while processing on_error callback: %s" % e, file=sys.stderr
|
|
|
|
|
)
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
async def run(self) -> ChatCompleteServiceResponse:
|
|
|
|
@ -111,7 +141,9 @@ class ChatCompleteTask:
|
|
|
|
|
self.result = chat_res
|
|
|
|
|
|
|
|
|
|
if self.transatcion_id:
|
|
|
|
|
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
|
|
|
|
|
await self.mwapi.ai_toolbox_end_transaction(
|
|
|
|
|
self.transatcion_id, chat_res["total_tokens"]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
await self._on_finished()
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -121,12 +153,14 @@ class ChatCompleteTask:
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
if self.transatcion_id:
|
|
|
|
|
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
|
|
|
|
|
await self.mwapi.ai_toolbox_cancel_transaction(
|
|
|
|
|
self.transatcion_id, error=err_msg
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
await self._on_error(e)
|
|
|
|
|
finally:
|
|
|
|
|
await self.end()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def end(self):
|
|
|
|
|
await self.chat_complete_service.__aexit__(None, None, None)
|
|
|
|
|
if self.task_id in chat_complete_tasks:
|
|
|
|
@ -134,13 +168,20 @@ class ChatCompleteTask:
|
|
|
|
|
self.is_finished = True
|
|
|
|
|
self.finished_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TASK_EXPIRE_TIME = 60 * 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def chat_complete_task_gc():
|
|
|
|
|
now = time.time()
|
|
|
|
|
for task_id in chat_complete_tasks.keys():
|
|
|
|
|
task = chat_complete_tasks[task_id]
|
|
|
|
|
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
|
|
|
|
|
if (
|
|
|
|
|
task.is_finished
|
|
|
|
|
and task.finished_time is not None
|
|
|
|
|
and now > task.finished_time + TASK_EXPIRE_TIME
|
|
|
|
|
):
|
|
|
|
|
del chat_complete_tasks[task_id]
|
|
|
|
|
|
|
|
|
|
noawait.add_timer(chat_complete_task_gc, 60)
|
|
|
|
|
|
|
|
|
|
noawait.add_timer(chat_complete_task_gc, 60)
|
|
|
|
|