import asyncio
import json
import time
import traceback

from aiohttp import WSMsgType, web
from sqlalchemy import select

from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web


class ChatCompleteWebSocketController:
    """Drive a single chat-complete WebSocket connection.

    Incoming text frames are JSON objects with an ``event`` key. A watchdog
    task closes the socket when no frame has been parsed successfully for
    more than ``IDLE_TIMEOUT`` seconds.
    """

    # Seconds without a successfully parsed text frame before the watchdog
    # closes the connection.
    IDLE_TIMEOUT = 30

    def __init__(self, request: web.Request):
        self.request = request
        self.ws = None
        self.db = None
        self.query = None
        self.chat_complete = None
        # Set once the connection is finished; stops the watchdog loop.
        self.closed = False
        # time.time() of connection setup or of the last parsed text frame.
        self.refreshed_time = 0
        # Reference to the running watchdog task (see run()).
        self.timeout_task = None

    async def run(self):
        """Prepare the WebSocket and process frames until it closes.

        Returns:
            The prepared ``web.WebSocketResponse``, so that a route handler
            may return it directly.
        """
        self.ws = web.WebSocketResponse()
        await self.ws.prepare(self.request)
        self.refreshed_time = time.time()

        self.db = await DatabaseService.create(self.request.app)
        self.query = self.request.query

        # Keep a reference to the watchdog instead of firing and forgetting.
        self.timeout_task = asyncio.ensure_future(self._timeout_task())

        try:
            async for msg in self.ws:
                if msg.type == WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        event = data.get('event')
                        self.refreshed_time = time.time()
                        if event == 'chatcomplete':
                            asyncio.ensure_future(self._chatcomplete(data))
                        elif event == 'ping':
                            await self.ws.send_json({
                                'event': 'pong'
                            })
                    except Exception as e:
                        print(e)
                        traceback.print_exc()
                        # The peer may already be gone; do not let the error
                        # report itself abort the receive loop.
                        if not self.ws.closed:
                            await self.ws.send_json({
                                'event': 'error',
                                'error': str(e)
                            })
                elif msg.type == WSMsgType.ERROR:
                    print(f'ws connection closed with exception {self.ws.exception()}')
        finally:
            # Let the watchdog leave its loop on its next wake-up. It is not
            # cancelled on purpose: it may be in the middle of ws.close().
            self.closed = True

        return self.ws

    async def _timeout_task(self):
        """Close the socket once it has been idle longer than IDLE_TIMEOUT.

        Polls once per second and exits as soon as ``self.closed`` is set.
        """
        while not self.closed:
            if time.time() - self.refreshed_time > self.IDLE_TIMEOUT:
                self.closed = True
                await self.ws.close()
                return
            await asyncio.sleep(1)

    async def _chatcomplete(self, params: dict):
        """Handle a ``chatcomplete`` event.

        Currently only reads the ``question`` and ``conversation_id`` fields
        from the event payload; nothing is done with them yet.
        """
        question = params.get("question")
        conversation_id = params.get("conversation_id")


class ChatComplete:
    """HTTP/WebSocket handlers for conversation chunks and chat completion."""

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the id and ``updated_at`` of every chunk of a conversation.

        Responds with 404 when the conversation does not exist and with 403
        when it belongs to a different user. Chunks are ordered by ascending
        id.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        # Requests flagged as coming from a user use the identity stored on
        # the request; other callers name the user explicitly.
        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)
        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Select the two columns directly and read full rows.
            # session.scalars() would yield only the first column (the id),
            # which has no .id / .updated_at attributes.
            stmt = (
                select(ConversationChunkModel.id, ConversationChunkModel.updated_at)
                .where(ConversationChunkModel.conversation_id == conversation_id)
                .order_by(ConversationChunkModel.id.asc())
            )
            rows = await session.execute(stmt)

            conversation_chunk_list = [
                {"id": chunk_id, "updated_at": updated_at}
                for chunk_id, updated_at in rows
            ]

            return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return a single conversation chunk by id.

        Responds with 404 when the chunk does not exist and with 403 when
        its conversation belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # NOTE(review): this reads the `conversation` relationship on an
            # async session; presumably it is eagerly loaded - confirm
            # against the model definition.
            if conversation_data.conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # NOTE(review): __dict__ of a mapped instance also carries ORM
            # bookkeeping entries; verify that api_response can serialise it.
            return await utils.web.api_response(1, conversation_data.__dict__, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of the ``question`` parameter."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Answer a question about a page, streaming the reply over WebSocket.

        Non-WebSocket requests are rejected with a 400 ``protocol-mismatch``
        response. Otherwise the following frames are sent to the client:

        * ``"+" + text`` string frames for each piece of the answer,
        * ``extract_doc`` JSON events,
        * a final ``done`` JSON event, or an ``error`` JSON event.

        When the caller is a user, a MediaWiki transaction is started before
        the completion, ended on success and cancelled on failure.

        Returns:
            The API error response for non-WebSocket requests, otherwise the
            ``web.WebSocketResponse`` that was served.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async def send_internal_error(err_msg: str):
            # Report an internal error to the client unless the socket is
            # already closed.
            if not ws.closed:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': err_msg,
                    'error': {
                        'code': 'internal_error',
                        'title': page_title,
                    },
                })

        async def on_message(text: str):
            # Answer text is sent as a raw string prefixed with "+" rather
            # than as JSON, to keep the frames small.
            await ws.send_str("+" + text)

        async def on_extracted_doc(doc: list):
            await ws.send_json({
                'event': 'extract_doc',
                'status': 1,
                'doc': doc
            })

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if await chat_complete_service.page_index_exists():
                    tokens = await tiktoken.get_tokens(question)

                    transaction_id = None
                    if caller == "user":
                        transaction_id = await mwapi.chat_complete_start_transaction(
                            user_id, "chatcomplete", tokens, extra_limit)

                    try:
                        chat_res = await chat_complete_service.chat_complete(
                            question, on_message, on_extracted_doc,
                            conversation_id=conversation_id, user_id=user_id,
                            embedding_search={
                                "limit": extra_limit,
                                "in_collection": in_collection,
                            })

                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            **chat_res,
                        })

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(
                                transaction_id, chat_res["total_tokens"])
                    except Exception as e:
                        err_msg = f"Error while processing chat complete request: {e}"
                        traceback.print_exc()

                        try:
                            await send_internal_error(err_msg)
                        finally:
                            # Cancel the transaction even when notifying the
                            # client fails (e.g. the socket dropped mid-send).
                            if transaction_id:
                                await mwapi.chat_complete_cancel_transaction(
                                    transaction_id, error=err_msg)
                else:
                    await ws.send_json({
                        'event': 'error',
                        'status': -2,
                        'message': "Page index not found.",
                        'error': {
                            'code': 'page_not_found',
                            'title': page_title,
                        },
                    })
        except Exception as e:
            # Anything not handled above, e.g. the websocket closing while a
            # frame was being sent.
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            await send_internal_error(err_msg)
        finally:
            if not ws.closed:
                await ws.close()

        # aiohttp handlers are expected to return the response they served.
        return ws