from __future__ import annotations

import asyncio
import json
import sys
import time
import traceback
from typing import Optional, Callable, TypedDict

from aiohttp import web
from sqlalchemy import select

from local import noawait
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web

# Registry of chat-completion tasks, keyed by ``ChatCompleteTask.task_id``.
# Entries are added by ``ChatComplete.start_chat_complete`` and removed by
# ``ChatCompleteTask._exit``.
chat_complete_tasks: dict[str, ChatCompleteTask] = {}


class ChatCompleteTask:
    """One chat-completion request: preparation, execution and cleanup.

    Typical use (see ``ChatComplete.start_chat_complete``): construct the task,
    ``await init(...)``, register it in ``chat_complete_tasks``, then call
    ``start()``.

    Callbacks appended to ``on_message`` are awaited with every delta message;
    callbacks appended to ``on_error`` are awaited with the exception raised
    while running the task.
    """

    def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system: bool = False):
        """Store the request context; no I/O happens here.

        Args:
            dbs: Database service handed to ``ChatCompleteService``.
            user_id: User the request is made for.
            page_title: Title of the wiki page the chat is about.
            is_system: When true, no usage transaction is started in ``init``.
        """
        self.task_id = utils.web.generate_uuid()
        self.on_message: list[Callable] = []
        self.on_error: list[Callable] = []
        self.chunks: list[str] = []

        # Declared here, assigned in ``init``.
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService

        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        self.is_system = is_system

        # NOTE: the attribute name is misspelled ("transatcion"); it is kept
        # as-is because code outside this file may read it.
        self.transatcion_id: Optional[str] = None
        self.point_cost: int = 0

    async def init(self, question: str, conversation_id: Optional[str] = None,
                   embedding_search: Optional[EmbeddingSearchArgs] = None):
        """Prepare the chat completion and, for non-system callers, open a usage transaction.

        Args:
            question: The user's question.
            conversation_id: Conversation to continue, if any.
            embedding_search: Embedding-search options; its ``"limit"`` entry
                (default 10) is reported to the usage transaction.

        Returns:
            Whatever ``ChatCompleteService.prepare_chat_complete`` returns.

        Raises:
            MediaWikiPageNotFoundException: The page has no index.
        """
        self.tiktoken = await TikTokenService.create()
        self.mwapi = MediaWikiApi.create()

        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        self.chat_complete = await self.chat_complete_service.__aenter__()

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)

            question_tokens = await self.tiktoken.get_tokens(question)
            # ``embedding_search`` is optional, so it must not be subscripted blindly.
            extract_limit = (embedding_search or {}).get("limit") or 10

            self.transatcion_id = None
            self.point_cost = 0
            if not self.is_system:
                usage_res = await self.mwapi.chat_complete_start_transaction(
                    self.user_id, "chatcomplete", question_tokens, extract_limit)
                self.transatcion_id = usage_res.get("transaction_id")
                self.point_cost = usage_res.get("point_cost")

            return await self.chat_complete.prepare_chat_complete(
                question,
                conversation_id=conversation_id,
                user_id=self.user_id,
                embedding_search=embedding_search,
            )
        except BaseException:
            # Release the service context on any failure (including
            # cancellation), not only when the page is missing.
            await self._exit()
            raise

    async def _on_message(self, delta_message: str):
        """Await every ``on_message`` callback with ``delta_message``, in order."""
        for callback in self.on_message:
            await callback(delta_message)

    async def _on_error(self, err: Exception):
        """Await every ``on_error`` callback with ``err``, in order."""
        for callback in self.on_error:
            await callback(err)

    async def run(self):
        """Finish the chat completion and settle the usage transaction.

        On success the transaction (if one was opened) is ended with the
        reported ``total_tokens``. On failure the error is printed to stderr,
        the transaction is cancelled and the ``on_error`` callbacks are
        awaited; the exception is not re-raised. Cleanup always runs.
        """
        try:
            # NOTE(review): ``finish_chat_complete`` is given an async callback
            # and its result is subscripted below, so it is awaited here —
            # confirm it is a coroutine function.
            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

            await self.chat_complete.set_latest_point_cost(self.point_cost)

            if self.transatcion_id:
                await self.mwapi.chat_complete_end_transaction(
                    self.transatcion_id, chat_res["total_tokens"])
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"

            print(err_msg, file=sys.stderr)
            traceback.print_exc()

            if self.transatcion_id:
                await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)

            await self._on_error(e)
        finally:
            await self._exit()

    async def _exit(self):
        """Leave the service context and drop this task from the registry."""
        await self.chat_complete_service.__aexit__(None, None, None)
        # ``pop`` rather than ``del``: ``init`` may fail before the task has
        # been registered, and a KeyError here would mask the real error.
        chat_complete_tasks.pop(self.task_id, None)

    @noawait.wrap
    async def start(self):
        """Run the task through ``noawait.wrap``."""
        await self.run()


class ChatComplete:
    """HTTP handlers for the chat-completion API."""

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """Return ``id`` and ``updated_at`` of every chunk in a conversation, ordered by id.

        Responds 404 when the conversation does not exist and 403 when it
        belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        # A "user" caller acts as itself; other callers name the user explicitly.
        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)
        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Select the two columns directly and read full rows:
            # ``scalars()`` would yield only the first column (the id).
            stmt = (
                select(ConversationChunkModel.id, ConversationChunkModel.updated_at)
                .where(ConversationChunkModel.conversation_id == conversation_id)
                .order_by(ConversationChunkModel.id.asc())
            )

            rows = await session.execute(stmt)

            conversation_chunk_list = [
                {"id": row.id, "updated_at": row.updated_at}
                for row in rows
            ]

            return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return a single conversation chunk by id.

        Responds 404 when the chunk does not exist and 403 when its
        conversation belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        # A "user" caller acts as itself; other callers name the user explicitly.
        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            if conversation_data.conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # An ORM instance's ``__dict__`` also carries SQLAlchemy's private
            # ``_sa_instance_state``; expose only the public attributes.
            chunk_data = {
                key: value
                for key, value in vars(conversation_data).items()
                if not key.startswith("_")
            }

            return await utils.web.api_response(1, chunk_data, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of the ``question`` parameter."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def start_chat_complete(request: web.Request):
        """Create, register and start a chat-completion task.

        On success the response carries ``question_tokens``, ``extract_doc``
        and the new ``task_id``. Responds 404 when the page is not found and
        500 on any other failure.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")

        extract_limit = params.get("extract_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)

        try:
            # Any caller other than "user" is treated as a system caller.
            chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
            init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
                "limit": extract_limit,
                "in_collection": in_collection,
            })

            chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task

            chat_complete_task.start()

            # ``api_response`` is awaited at every other call site; returning
            # the bare coroutine would not produce a response.
            return await utils.web.api_response(1, data={
                "question_tokens": init_res["question_tokens"],
                "extract_doc": init_res["extract_doc"],
                "task_id": chat_complete_task.task_id,
            }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-1, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404, request=request)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "chat-complete-error",
                "message": err_msg
            }, http_status=500, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete_stream(request: web.Request):
        """Placeholder handler; not implemented yet."""
        pass