From 6a74050b56e79b7f56a459e4b294b841cad450a9 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Mon, 29 May 2023 14:26:52 +0000 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0ChatComplete=E7=9A=84API?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controller/ChatComplete.py | 232 ++++++++++++++++++--------------- api/route.py | 3 +- service/chat_complete.py | 2 +- utils/web.py | 6 +- 4 files changed, 133 insertions(+), 110 deletions(-) diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index 2fcbd14..5685c12 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -1,26 +1,103 @@ +from __future__ import annotations import asyncio import json +import sys import time import traceback from local import noawait -from typing import Optional -from aiohttp import WSMsgType, web +from typing import Optional, Callable, TypedDict +from aiohttp import web from sqlalchemy import select from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel from noawait import NoAwaitPool from service.chat_complete import ChatCompleteService from service.database import DatabaseService -from service.mediawiki_api import MediaWikiApi +from service.embedding_search import EmbeddingSearchArgs +from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException from service.tiktoken import TikTokenService import utils.web -class ChatCompleteTaskList: - def __init__(self, dbs: DatabaseService): - self.on_message = None +chat_complete_tasks: dict[str, ChatCompleteTask] = {} + +class ChatCompleteTask: + def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False): + self.task_id = utils.web.generate_uuid() + self.on_message: list[Callable] = [] + self.on_error: list[Callable] = [] self.chunks: list[str] = [] - async def run(): - pass + self.chat_complete_service: ChatCompleteService + self.chat_complete: ChatCompleteService + + self.dbs = dbs + self.user_id = user_id + self.page_title = page_title + self.is_system = is_system + + self.transatcion_id: Optional[str] = None + self.point_cost: int = 0 + + async def init(self, question: str, conversation_id: Optional[str] = None, + embedding_search: Optional[EmbeddingSearchArgs] = None): + self.tiktoken = await TikTokenService.create() + + self.mwapi = MediaWikiApi.create() + + self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title) + self.chat_complete = await self.chat_complete_service.__aenter__() + + if await self.chat_complete.page_index_exists(): + question_tokens = await self.tiktoken.get_tokens(question) + + extract_limit = embedding_search["limit"] or 10 + + self.transatcion_id: Optional[str] = None + self.point_cost: int = 0 + if not self.is_system: + usage_res = await self.mwapi.chat_complete_start_transaction(self.user_id, "chatcomplete", question_tokens, extract_limit) + self.transatcion_id = usage_res.get("transaction_id") + self.point_cost = usage_res.get("point_cost") + + chat_res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, + user_id=self.user_id, embedding_search=embedding_search) + + return chat_res + else: + await self._exit() + raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) + + async def _on_message(self, delta_message: str): + for callback in self.on_message: + await callback(delta_message) + + async def _on_error(self, err: Exception): + for callback in self.on_error: + await callback(err) + + async def run(self): + try: + chat_res = self.chat_complete.finish_chat_complete(self._on_message) + + await self.chat_complete.set_latest_point_cost(self.point_cost) + + if self.transatcion_id: + result = await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"]) + except Exception as e: + err_msg = f"Error while processing chat complete request: {e}" + + print(err_msg, file=sys.stderr) + traceback.print_exc() + + if self.transatcion_id: + result = await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg) + + await self._on_error(e) + finally: + await self._exit() + + async def _exit(self): + await self.chat_complete_service.__aexit__(None, None, None) + del chat_complete_tasks[self.task_id] @noawait.wrap async def start(self): @@ -144,7 +221,7 @@ class ChatComplete: @staticmethod @utils.web.token_auth - async def chat_complete(request: web.Request): + async def start_chat_complete(request: web.Request): params = await utils.web.get_param(request, { "title": { "type": str, @@ -158,7 +235,7 @@ class ChatComplete: "type": int, "required": False, }, - "extra_limit": { + "extract_limit": { "type": int, "required": False, "default": 10, @@ -177,103 +254,44 @@ class ChatComplete: question = params.get("question") conversation_id = params.get("conversation_id") - extra_limit = params.get("extra_limit") + extract_limit = params.get("extract_limit") in_collection = params.get("in_collection") dbs = await DatabaseService.create(request.app) - tiktoken = await TikTokenService.create() - mwapi = MediaWikiApi.create() - if utils.web.is_websocket(request): - ws = web.WebSocketResponse() - await ws.prepare(request) - - try: - async with ChatCompleteService(dbs, page_title) as chat_complete_service: - if await chat_complete_service.page_index_exists(): - tokens = await tiktoken.get_tokens(question) - - transatcion_id = None - point_cost = 0 - if request.get("caller") == "user": - usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit) - transatcion_id = usage_res.get("transaction_id") - point_cost = usage_res.get("point_cost") - - async def on_message(text: str): - # Send message to client, start with "+" to indicate it's a message - # use json will make the package 10x larger - await ws.send_str("+" + text) - - async def on_extracted_doc(doc: list): - await ws.send_json({ - 'event': 'extract_doc', - 'status': 1, - 'doc': doc - }) - - try: - chat_res = await chat_complete_service \ - .prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={ - "limit": extra_limit, - "in_collection": in_collection, - }) - await ws.send_json({ - 'event': 'done', - 'status': 1, - **chat_res, - }) - - await chat_complete_service.set_latest_point_cost(point_cost) - - if transatcion_id: - result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"]) - except Exception as e: - err_msg = f"Error while processing chat complete request: {e}" - traceback.print_exc() - - if not ws.closed: - await ws.send_json({ - 'event': 'error', - 'status': -1, - 'message': err_msg, - 'error': { - 'code': 'internal_error', - 'title': page_title, - }, - }) - if transatcion_id: - result = await mwapi.chat_complete_cancel_transaction(transatcion_id, error=err_msg) - else: - await ws.send_json({ - 'event': 'error', - 'status': -2, - 'message': "Page index not found.", - 'error': { - 'code': 'page_not_found', - 'title': page_title, - }, - }) - - # websocket closed - except Exception as e: - err_msg = f"Error while processing chat complete request: {e}" - traceback.print_exc() - - if not ws.closed: - await ws.send_json({ - 'event': 'error', - 'status': -1, - 'message': err_msg, - 'error': { - 'code': 'internal_error', - 'title': page_title, - }, - }) - finally: - if not ws.closed: - await ws.close() - else: - return await utils.web.api_response(-1, request=request, error={ - "code": "protocol-mismatch", - "message": "Protocol mismatch, websocket request expected." - }, http_status=400) + + try: + chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user") + init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={ + "limit": extract_limit, + "in_collection": in_collection, + }) + + chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task + + chat_complete_task.start() + + return utils.web.api_response(1, data={ + "question_tokens": init_res["question_tokens"], + "extract_doc": init_res["extract_doc"], + "task_id": chat_complete_task.task_id, + }, request=request) + except MediaWikiPageNotFoundException as e: + error_msg = "Page \"%s\" not found." % page_title + return await utils.web.api_response(-1, error={ + "code": "page-not-found", + "title": page_title, + "message": error_msg + }, http_status=404, request=request) + except Exception as e: + err_msg = f"Error while processing chat complete request: {e}" + traceback.print_exc() + + return await utils.web.api_response(-1, error={ + "code": "chat-complete-error", + "message": err_msg + }, http_status=500, request=request) + + @staticmethod + @utils.web.token_auth + async def chat_complete_stream(request: web.Request): + pass \ No newline at end of file diff --git a/api/route.py b/api/route.py index 06cdd1b..e7d0bd0 100644 --- a/api/route.py +++ b/api/route.py @@ -29,5 +29,6 @@ def init(app: web.Application): web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk), - web.route('*', '/chatcomplete/message', ChatComplete.chat_complete), + web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete), + web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), ]) diff --git a/service/chat_complete.py b/service/chat_complete.py index cedca18..e2d6e54 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -83,7 +83,7 @@ class ChatCompleteService: async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None, - embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse: + embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse: if user_id is not None: user_id = int(user_id) diff --git a/utils/web.py b/utils/web.py index f3eaa61..508642b 100644 --- a/utils/web.py +++ b/utils/web.py @@ -4,6 +4,7 @@ from typing import Any, Optional, Dict from aiohttp import web import jwt import config +import uuid ParamRule = Dict[str, Any] @@ -86,7 +87,10 @@ async def api_response(status, data=None, error=None, warning=None, http_status= return web.json_response(ret, status=http_status) def is_websocket(request: web.Request): - return request.headers.get('Upgrade', '').lower() == 'websocket' + return request.headers.get('Upgrade', '').lower() == 'websocket' + +def generate_uuid(): + return str(uuid.uuid4()) # Auth decorator def token_auth(f):