# isekai-toolkit/server/controller/task/ChatCompleteTask.py
#
# 166 lines
# 5.8 KiB
# Python

from __future__ import annotations
import sys
import time
import traceback
from events.chat_complete_event import ChatCompleteEvent, ChatCompleteTaskEvent
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from type_defs.chat_complete_task import ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
ChatCompleteService,
calculate_point_usage,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web
# Module-wide registry of ChatCompleteTask objects, keyed by task_id.
# ChatCompleteTask.run() adds an entry; ChatCompleteTask.end() removes it.
chat_complete_tasks: dict[str, ChatCompleteTask] = dict()
class ChatCompleteTask:
    """One chat-complete request, from preparation through execution to cleanup.

    Usage: construct, ``await init(...)`` to open the service, estimate the
    point cost and (for non-system users) start a MediaWiki point
    transaction, then ``await run()`` to execute the request. While
    ``run()`` is executing, the task is registered in ``chat_complete_tasks``
    so it can be looked up with ``get_by_id``.
    """

    @staticmethod
    def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
        """Return the registered task with ``task_id``, or ``None`` if absent."""
        return chat_complete_tasks.get(task_id)

    def __init__(
        self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
    ):
        self.task_id = utils.web.generate_uuid()
        self.events: ChatCompleteTaskEvent = ChatCompleteTaskEvent()
        self.chunks: list[str] = []
        # Both are assigned in init(); declared here only for type checkers.
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService
        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        # System tasks never open a MediaWiki point transaction.
        self.is_system = is_system
        # NOTE(review): "transatcion" is misspelled; the name is kept in case
        # other modules read this attribute.
        self.transatcion_id: Optional[str] = None
        self.point_usage: int = 0
        self.is_finished = False
        self.finished_time: Optional[float] = None
        # Exactly one of these is filled in by run(), depending on the outcome.
        self.result: Optional[ChatCompleteServiceResponse] = None
        self.error: Optional[Exception] = None

    async def init(
        self,
        question: str,
        bot_id: str,
        conversation_id: Optional[str] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Open the service, reserve points and prepare the chat-complete request.

        Args:
            question: The user's question text.
            bot_id: Bot whose persona defines the point cost; also forwarded
                to the transaction and to ``prepare_chat_complete``.
            conversation_id: Forwarded to ``prepare_chat_complete``.
            edit_message_id: Forwarded to ``prepare_chat_complete``.
            embedding_search: Forwarded to ``prepare_chat_complete``. Its
                ``"limit"`` (default 10) sizes the token estimate.

        Returns:
            The result of ``ChatCompleteService.prepare_chat_complete``.

        Raises:
            MediaWikiPageNotFoundException: If the page has no index.
            Exception: Anything raised during preparation is re-raised after
                cancelling any started transaction and calling ``end()``.
        """
        self.tiktoken = await TikTokenService.create()
        self.mwapi = MediaWikiApi.create()
        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        self.chat_complete = await self.chat_complete_service.__aenter__()

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException(
                    "Page %s not found." % self.page_title
                )

            question_tokens = await self.tiktoken.get_tokens(question)

            # embedding_search is optional, and "limit" may be missing or
            # falsy; fall back to 10 extracted documents in all those cases.
            extract_limit = (embedding_search or {}).get("limit") or 10
            estimated_extract_tokens_per_doc = Config.get(
                "estimated_extract_tokens_per_doc", 50, int
            )
            # Estimated size of the extracted context: docs x tokens-per-doc.
            # NOTE(review): question_tokens is not part of this estimate;
            # confirm that is intended.
            predict_tokens = extract_limit * estimated_extract_tokens_per_doc

            async with BotPersonaHelper(self.dbs) as bot_persona_helper:
                bot_persona = await bot_persona_helper.find_by_bot_id(bot_id)
                # Read the persona's cost fields while the helper is still open.
                self.point_usage = calculate_point_usage(
                    predict_tokens,
                    bot_persona.cost_fixed,
                    bot_persona.cost_fixed_tokens,
                    bot_persona.cost_per_token,
                )

            if not self.is_system:
                usage_res = await self.mwapi.ai_toolbox_start_transaction(
                    self.user_id, "chatcomplete",
                    bot_id=bot_id,
                    tokens=predict_tokens,
                    point_usage=self.point_usage,
                )
                self.transatcion_id = usage_res["transaction_id"]

            return await self.chat_complete.prepare_chat_complete(
                question,
                conversation_id=conversation_id,
                user_id=self.user_id,
                question_tokens=question_tokens,
                edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search=embedding_search,
            )
        except Exception as e:
            # Release any points reserved above; otherwise the transaction
            # started in this method would be left open.
            await self._cancel_transaction(
                f"Error while preparing chat complete request: {e}"
            )
            await self.end()
            raise

    async def _cancel_transaction(self, err_msg: str) -> None:
        """Cancel the open point transaction, if any, without raising.

        A failure to cancel is reported on stderr, so it never masks the
        error the caller is already handling.
        """
        if not self.transatcion_id:
            return
        try:
            await self.mwapi.ai_toolbox_cancel_transaction(
                self.transatcion_id, error=err_msg
            )
        except Exception:
            print(
                "Failed to cancel point transaction %s" % self.transatcion_id,
                file=sys.stderr,
            )
            traceback.print_exc()

    async def run(self) -> Optional[ChatCompleteServiceResponse]:
        """Execute the prepared request and settle the point transaction.

        The task is registered in ``chat_complete_tasks`` for the duration.

        On success:
            * the response is stored in ``self.result``;
            * the transaction (if any) is ended;
            * ``emit_finished`` is fired.

        On failure:
            * the exception is stored in ``self.error``;
            * the transaction (if any) is cancelled;
            * ``emit_error`` is fired;
            * the exception is not re-raised.

        ``end()`` is always called.

        Returns:
            ``self.result``, which is ``None`` when the request failed.
        """
        chat_complete_tasks[self.task_id] = self
        try:
            chat_res = await self.chat_complete.finish_chat_complete(
                event_container=self.events
            )
            await self.chat_complete.set_latest_point_usage(self.point_usage)
            self.result = chat_res
            if self.transatcion_id:
                # TODO: deduct points based on the tokens actually used.
                await self.mwapi.ai_toolbox_end_transaction(
                    self.transatcion_id, self.point_usage
                )
            await self.events.emit_finished(chat_res)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            print(err_msg, file=sys.stderr)
            traceback.print_exc()
            self.error = e
            await self._cancel_transaction(err_msg)
            await self.events.emit_error(e)
        finally:
            await self.end()
        return self.result

    async def end(self):
        """Close the chat-complete service and mark the task finished.

        The task is unregistered from ``chat_complete_tasks`` and its finish
        time is recorded, even if closing the service raises.
        """
        try:
            await self.chat_complete_service.__aexit__(None, None, None)
        finally:
            # NOTE(review): the task leaves the registry immediately, so the
            # expiry check in chat_complete_task_gc never sees a finished
            # task, and get_by_id() returns None once a task has ended.
            # Confirm this is intended.
            chat_complete_tasks.pop(self.task_id, None)
            self.is_finished = True
            self.finished_time = time.time()
# Seconds a finished task may remain in ``chat_complete_tasks`` before
# chat_complete_task_gc removes it.
TASK_EXPIRE_TIME = 60 * 10


async def chat_complete_task_gc():
    """Remove finished tasks older than ``TASK_EXPIRE_TIME`` from the registry.

    Expired ids are collected first and deleted afterwards. Deleting from a
    dict while iterating over it (or over its keys view) raises
    ``RuntimeError``.
    """
    now = time.time()
    expired_ids = [
        task_id
        for task_id, task in chat_complete_tasks.items()
        if task.is_finished
        and task.finished_time is not None
        and now > task.finished_time + TASK_EXPIRE_TIME
    ]
    for task_id in expired_ids:
        chat_complete_tasks.pop(task_id, None)


# Schedule the GC periodically. Presumably the second argument is the
# interval in seconds; confirm against utils.local.noawait.add_timer.
noawait.add_timer(chat_complete_task_gc, 60)