You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

203 lines
6.9 KiB
Python

from __future__ import annotations
import sys
import time
import traceback
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
ChatCompleteService,
ChatCompleteServicePrepareResponse,
ChatCompleteServiceResponse,
calculate_point_usage,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web
# Module-level registry of chat-complete tasks keyed by task id.
# ChatCompleteTask.run() inserts the task, ChatCompleteTask.end() removes it,
# and ChatCompleteTask.get_by_id() looks entries up.
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
    """A single chat-completion request, from preparation to completion.

    Usage: construct the task, ``await init(...)`` to prepare the request
    (for non-system tasks this also opens a point transaction through the
    MediaWiki API), then ``await run()`` to produce the answer.

    Listeners are plain lists of async callables:

    * ``on_message``  - awaited with every streamed text chunk,
    * ``on_finished`` - awaited with the final result,
    * ``on_error``    - awaited with the exception that aborted ``run()``.

    A listener that raises is logged to stderr and does not stop the
    remaining listeners.
    """

    # Fallback number of documents assumed when no limit is supplied.
    _DEFAULT_EXTRACT_LIMIT = 10

    @staticmethod
    def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
        """Return the registered task with ``task_id``, or ``None``."""
        return chat_complete_tasks.get(task_id)

    @staticmethod
    def _resolve_extract_limit(embedding_search) -> int:
        """Return the ``"limit"`` of ``embedding_search`` or the default.

        Falls back to the default when ``embedding_search`` is ``None``, has
        no ``"limit"`` key, or the stored limit is falsy (``None`` / ``0``).
        """
        if embedding_search is None:
            return ChatCompleteTask._DEFAULT_EXTRACT_LIMIT
        try:
            limit = embedding_search["limit"]
        except KeyError:
            limit = None
        return limit or ChatCompleteTask._DEFAULT_EXTRACT_LIMIT

    @staticmethod
    async def _notify(callbacks, event_name: str, *args) -> None:
        """Await every callback with ``args``; log failures and keep going."""
        for callback in callbacks:
            try:
                await callback(*args)
            except Exception as e:
                print(
                    f"Error while processing {event_name} callback: {e}",
                    file=sys.stderr,
                )
                traceback.print_exc()

    def __init__(
        self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
    ):
        self.task_id = utils.web.generate_uuid()

        # Listener lists; see the class docstring.
        self.on_message: list[Callable] = []
        self.on_finished: list[Callable] = []
        self.on_error: list[Callable] = []

        # Every streamed chunk received so far, in arrival order.
        self.chunks: list[str] = []

        # Both are assigned by init(); declared here for type checkers only.
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService

        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        # System tasks never open a point transaction (see init()).
        self.is_system = is_system

        # NOTE: the attribute name is misspelled, but it is public and is
        # kept as-is so existing readers of the attribute keep working.
        self.transatcion_id: Optional[str] = None
        self.point_usage: int = 0

        self.is_finished = False
        self.finished_time: Optional[float] = None
        self.result: Optional[ChatCompleteServiceResponse] = None
        self.error: Optional[Exception] = None

    async def init(
        self,
        question: str,
        bot_id: str,
        conversation_id: Optional[str] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Prepare the chat-completion request.

        Enters the chat-complete service, computes the point usage for the
        bot persona, opens a point transaction unless the task is a system
        task, and prepares the request.

        Args:
            question: The user's question.
            bot_id: Id of the bot persona used for pricing and answering.
            conversation_id: Forwarded to ``prepare_chat_complete``.
            edit_message_id: Forwarded to ``prepare_chat_complete``.
            embedding_search: Forwarded to ``prepare_chat_complete``; its
                ``"limit"`` (default 10) also feeds the token prediction.
                May be ``None``.

        Returns:
            The value returned by ``prepare_chat_complete``.

        Raises:
            MediaWikiPageNotFoundException: the page has no index.
            Exception: anything raised while preparing. In every failure
                case a transaction opened by this call is cancelled
                (best effort) and ``end()`` is awaited before re-raising.
        """
        self.tiktoken = await TikTokenService.create()
        self.mwapi = MediaWikiApi.create()

        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        # Entered manually because the service must outlive this method;
        # end() performs the matching __aexit__.
        self.chat_complete = await self.chat_complete_service.__aenter__()

        # Id of the transaction opened by *this* call, if any.
        opened_transaction_id: Optional[str] = None

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException(
                    f"Page {self.page_title} not found."
                )

            question_tokens = await self.tiktoken.get_tokens(question)

            extract_limit = self._resolve_extract_limit(embedding_search)
            estimated_extract_tokens_per_doc = Config.get(
                "estimated_extract_tokens_per_doc", 50, int
            )
            # NOTE(review): this adds a document count to a per-document
            # token estimate; a product (possibly plus question_tokens) looks
            # more plausible. Left unchanged because it affects billing -
            # confirm the intended formula.
            predict_tokens = extract_limit + estimated_extract_tokens_per_doc

            async with BotPersonaHelper(self.dbs) as bot_persona_helper:
                bot_persona = await bot_persona_helper.find_by_bot_id(bot_id)
                self.point_usage = calculate_point_usage(
                    predict_tokens,
                    bot_persona.cost_fixed,
                    bot_persona.cost_fixed_tokens,
                    bot_persona.cost_per_token,
                )

            self.transatcion_id = None
            if not self.is_system:
                usage_res = await self.mwapi.ai_toolbox_start_transaction(
                    self.user_id,
                    "chatcomplete",
                    bot_id=bot_id,
                    tokens=predict_tokens,
                    point_usage=self.point_usage,
                )
                opened_transaction_id = usage_res["transaction_id"]
                self.transatcion_id = opened_transaction_id

            return await self.chat_complete.prepare_chat_complete(
                question,
                conversation_id=conversation_id,
                user_id=self.user_id,
                question_tokens=question_tokens,
                edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search=embedding_search,
            )
        except Exception as e:
            # Do not leave a transaction open for a request that will never
            # run. Cancellation is best effort so it cannot mask the original
            # error.
            if opened_transaction_id is not None:
                try:
                    await self.mwapi.ai_toolbox_cancel_transaction(
                        opened_transaction_id,
                        error=f"Error while preparing chat complete request: {e}",
                    )
                except Exception as cancel_err:
                    print(
                        f"Error while cancelling transaction: {cancel_err}",
                        file=sys.stderr,
                    )
                    traceback.print_exc()
                self.transatcion_id = None
            await self.end()
            raise

    async def _on_message(self, delta_message: str):
        """Record a streamed chunk and forward it to ``on_message`` listeners."""
        self.chunks.append(delta_message)
        await self._notify(self.on_message, "on_message", delta_message)

    async def _on_finished(self):
        """Forward ``self.result`` to ``on_finished`` listeners."""
        await self._notify(self.on_finished, "on_finished", self.result)

    async def _on_error(self, err: Exception):
        """Store ``err`` on the task and forward it to ``on_error`` listeners."""
        self.error = err
        await self._notify(self.on_error, "on_error", err)

    async def run(self) -> Optional[ChatCompleteServiceResponse]:
        """Execute the prepared request and notify listeners.

        Registers the task in ``chat_complete_tasks``, finishes the chat
        completion (streaming chunks through ``_on_message``), stores the
        point usage, ends the transaction if one is open, and notifies
        ``on_finished`` listeners.

        On failure the error is logged to stderr, an open transaction is
        cancelled, and ``on_error`` listeners are notified; the exception is
        not re-raised. ``end()`` is always awaited afterwards.

        Returns:
            The chat-completion result on success, ``None`` on failure.
        """
        chat_complete_tasks[self.task_id] = self
        try:
            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

            await self.chat_complete.set_latest_point_usage(self.point_usage)

            self.result = chat_res

            if self.transatcion_id:
                await self.mwapi.ai_toolbox_end_transaction(
                    # TODO: deduct points based on the tokens actually used
                    self.transatcion_id,
                    self.point_usage,
                )

            await self._on_finished()
            return chat_res
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"

            print(err_msg, file=sys.stderr)
            traceback.print_exc()

            if self.transatcion_id:
                await self.mwapi.ai_toolbox_cancel_transaction(
                    self.transatcion_id, error=err_msg
                )

            await self._on_error(e)
            return None
        finally:
            await self.end()

    async def end(self):
        """Release the chat-complete service and mark the task as finished.

        The task is removed from ``chat_complete_tasks`` and flagged as
        finished even if leaving the service raises, so a failing cleanup
        cannot leave a stale registry entry behind. Also safe to call when
        ``init()`` never created the service.
        """
        try:
            service = getattr(self, "chat_complete_service", None)
            if service is not None:
                await service.__aexit__(None, None, None)
        finally:
            chat_complete_tasks.pop(self.task_id, None)
            self.is_finished = True
            self.finished_time = time.time()
# Seconds a finished task may stay in the registry before the GC drops it.
TASK_EXPIRE_TIME = 60 * 10


async def chat_complete_task_gc():
    """Remove finished tasks that have outlived ``TASK_EXPIRE_TIME``.

    A task is dropped from ``chat_complete_tasks`` when it is marked
    finished, has a ``finished_time``, and that time lies more than
    ``TASK_EXPIRE_TIME`` seconds in the past. Unfinished tasks are kept.
    """
    now = time.time()

    # Collect the ids first: deleting from a dict while iterating over it
    # raises "RuntimeError: dictionary changed size during iteration".
    expired_ids = [
        task_id
        for task_id, task in chat_complete_tasks.items()
        if task.is_finished
        and task.finished_time is not None
        and now > task.finished_time + TASK_EXPIRE_TIME
    ]

    for task_id in expired_ids:
        chat_complete_tasks.pop(task_id, None)
# Register the task GC with noawait's timer at import time.
# NOTE(review): the second argument (60) is presumably the interval in
# seconds - confirm against utils.local.noawait.add_timer.
noawait.add_timer(chat_complete_task_gc, 60)