from __future__ import annotations

import sys
import time
import traceback
from typing import Callable, Optional, Union

from utils.local import noawait
from service.chat_complete import (
    ChatCompleteService,
    ChatCompleteServicePrepareResponse,
    ChatCompleteServiceResponse,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web

# Tasks registered by ChatCompleteTask.run(), keyed by task_id. Entries are
# removed again by ChatCompleteTask.end() and by chat_complete_task_gc().
chat_complete_tasks: dict[str, ChatCompleteTask] = {}

# Extract limit reported to the points transaction when the caller's
# embedding_search args are missing or carry no truthy "limit".
DEFAULT_EXTRACT_LIMIT = 10


class ChatCompleteTask:
    """A single chat-completion request for one page on behalf of one user.

    Lifecycle: construct, ``await init(...)`` to prepare the request (opening
    a points transaction unless ``is_system``), then ``await run()`` to
    produce the completion. Async callbacks appended to ``on_message``,
    ``on_finished`` and ``on_error`` are awaited as the task progresses;
    exceptions they raise are logged to stderr and otherwise ignored.
    """

    @staticmethod
    def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
        """Return the registered task with ``task_id``, or None if absent."""
        return chat_complete_tasks.get(task_id)

    def __init__(
        self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
    ):
        self.task_id = utils.web.generate_uuid()

        # Async callbacks, awaited as on_message(delta: str),
        # on_finished(result) and on_error(err: Exception).
        self.on_message: list[Callable] = []
        self.on_finished: list[Callable] = []
        self.on_error: list[Callable] = []

        # Every streamed delta, in arrival order.
        self.chunks: list[str] = []

        # Assigned in init(); only annotated until then.
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService

        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        self.is_system = is_system

        # NOTE: the attribute name is misspelled ("transatcion"); it is kept
        # unchanged because code outside this file may read it.
        self.transatcion_id: Optional[str] = None
        self.point_cost: int = 0

        self.is_finished = False
        self.finished_time: Optional[float] = None
        self.result: Optional[ChatCompleteServiceResponse] = None
        self.error: Optional[Exception] = None

    async def init(
        self,
        question: str,
        conversation_id: Optional[str] = None,
        edit_message_id: Optional[str] = None,
        bot_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Open the chat-complete service and prepare the request.

        For non-system tasks a points transaction is started via
        ``MediaWikiApi.ai_toolbox_start_transaction`` before preparing.

        Returns:
            The result of ``ChatCompleteService.prepare_chat_complete``.

        Raises:
            MediaWikiPageNotFoundException: if the page has no index.
            Exception: anything raised while preparing. Before re-raising,
                an opened transaction is cancelled and ``end()`` is called.
        """
        self.tiktoken = await TikTokenService.create()
        self.mwapi = MediaWikiApi.create()

        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        self.chat_complete = await self.chat_complete_service.__aenter__()

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException(
                    "Page %s not found." % self.page_title
                )

            question_tokens = await self.tiktoken.get_tokens(question)

            # embedding_search is Optional (default None) and may lack "limit";
            # assumes it is dict-like (e.g. a TypedDict) — TODO confirm.
            extract_limit = DEFAULT_EXTRACT_LIMIT
            if embedding_search:
                extract_limit = embedding_search.get("limit") or DEFAULT_EXTRACT_LIMIT

            self.transatcion_id = None
            self.point_cost = 0
            if not self.is_system:
                usage_res = await self.mwapi.ai_toolbox_start_transaction(
                    self.user_id, "chatcomplete", question_tokens, extract_limit
                )
                self.transatcion_id = usage_res["transaction_id"]
                self.point_cost = usage_res["point_cost"]

            return await self.chat_complete.prepare_chat_complete(
                question,
                conversation_id=conversation_id,
                user_id=self.user_id,
                edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search=embedding_search,
            )
        except Exception as e:
            # A transaction opened above would otherwise never be ended or
            # cancelled, because run() will not be reached.
            await self._cancel_transaction(
                f"Error while preparing chat complete request: {e}"
            )
            await self.end()
            raise

    async def _dispatch(self, callbacks: list[Callable], name: str, *args):
        """Await each callback with ``args``; log and swallow its errors."""
        for callback in callbacks:
            try:
                await callback(*args)
            except Exception as e:
                print(
                    "Error while processing %s callback: %s" % (name, e),
                    file=sys.stderr,
                )
                traceback.print_exc()

    async def _on_message(self, delta_message: str):
        """Record a streamed delta and forward it to on_message callbacks."""
        self.chunks.append(delta_message)
        await self._dispatch(self.on_message, "on_message", delta_message)

    async def _on_finished(self):
        """Pass the final result to on_finished callbacks."""
        await self._dispatch(self.on_finished, "on_finished", self.result)

    async def _on_error(self, err: Exception):
        """Store ``err`` on the task and pass it to on_error callbacks."""
        self.error = err
        await self._dispatch(self.on_error, "on_error", err)

    async def _cancel_transaction(self, err_msg: str):
        """Cancel the open points transaction, if any, logging any failure.

        Failures are logged rather than raised so that error handling in the
        caller (on_error callbacks, end()) still runs.
        """
        if not self.transatcion_id:
            return
        try:
            await self.mwapi.ai_toolbox_cancel_transaction(
                self.transatcion_id, error=err_msg
            )
        except Exception as cancel_err:
            print(
                "Error while cancelling transaction %s: %s"
                % (self.transatcion_id, cancel_err),
                file=sys.stderr,
            )
            traceback.print_exc()

    async def run(self) -> Optional[ChatCompleteServiceResponse]:
        """Run the prepared completion, settle the transaction and notify.

        The task is registered in ``chat_complete_tasks`` while running. On
        success the transaction is ended with the response's ``total_tokens``
        and on_finished callbacks run. On failure the error is logged, the
        transaction is cancelled and on_error callbacks run; the exception
        is not re-raised. ``end()`` is always called.

        Returns:
            The completion response, or None if the run failed.
        """
        chat_complete_tasks[self.task_id] = self
        try:
            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)
            await self.chat_complete.set_latest_point_cost(self.point_cost)
            self.result = chat_res
            if self.transatcion_id:
                await self.mwapi.ai_toolbox_end_transaction(
                    self.transatcion_id, chat_res["total_tokens"]
                )
            await self._on_finished()
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            print(err_msg, file=sys.stderr)
            traceback.print_exc()
            await self._cancel_transaction(err_msg)
            await self._on_error(e)
        finally:
            await self.end()

        return self.result if self.error is None else None

    async def end(self):
        """Close the chat-complete service, unregister, and mark finished.

        Only the first call does anything, so the service is never exited
        twice. Bookkeeping happens even if closing the service raises.
        """
        if self.is_finished:
            return
        try:
            # init() may have failed before the service was assigned.
            service = getattr(self, "chat_complete_service", None)
            if service is not None:
                await service.__aexit__(None, None, None)
        finally:
            chat_complete_tasks.pop(self.task_id, None)
            self.is_finished = True
            self.finished_time = time.time()


# Seconds a finished task may stay registered before the GC drops it.
TASK_EXPIRE_TIME = 60 * 10


async def chat_complete_task_gc():
    """Remove registered tasks that finished more than TASK_EXPIRE_TIME ago.

    NOTE(review): ChatCompleteTask.end() already unregisters a task when it
    finishes, so this only acts on tasks left registered some other way —
    confirm whether finished tasks are meant to stay retrievable.
    """
    now = time.time()
    # Iterate over a snapshot: deleting from the dict while iterating its
    # live view raises "RuntimeError: dictionary changed size during iteration".
    for task_id, task in list(chat_complete_tasks.items()):
        if (
            task.is_finished
            and task.finished_time is not None
            and now > task.finished_time + TASK_EXPIRE_TIME
        ):
            chat_complete_tasks.pop(task_id, None)


# Presumably schedules the GC every 60 seconds — see utils.local.noawait.
noawait.add_timer(chat_complete_task_gc, 60)