"""aiohttp request handlers for the chat-complete API.

Provides conversation-chunk listing/lookup, token counting, and the websocket
endpoint that answers questions about an indexed wiki page.
"""
import asyncio
import json
import time
import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web


class ChatCompleteTaskList:
    """Skeleton for a chat-complete task; ``run`` is currently a no-op stub."""

    def __init__(self, dbs: DatabaseService):
        # Kept so that ``run`` can use the database once it is implemented.
        self.dbs = dbs
        # Optional callback for message text; nothing in this class invokes it yet.
        self.on_message = None
        self.chunks: list[str] = []

    async def run(self):
        """Perform the task's work (not implemented yet)."""
        pass

    @noawait.wrap
    async def start(self):
        """Invoke ``run`` through the ``noawait.wrap`` decorator.

        NOTE(review): presumably the decorator lets callers start the task
        without awaiting it; confirm against ``local.noawait``.
        """
        await self.run()


class ChatComplete:
    """HTTP/websocket handlers for the chat-complete feature."""

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunks of a conversation as ``{"id", "updated_at"}`` items.

        Params:
            conversation_id (int, required): conversation whose chunks are listed.
            user_id (int, optional): owner to check against; only used when the
                request's ``caller`` is not ``"user"``.

        Responds with the chunk list ordered by chunk id, 404
        ``conversation-not-found`` if the conversation does not exist, or 403
        ``permission-denied`` if it belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        if request.get("caller") == "user":
            # End users are always checked against their own authenticated id.
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)
        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Select just the two needed columns and read whole rows:
            # ``session.scalars`` would keep only the first column (bare ids),
            # so ``row.updated_at`` would not be available.
            stmt = select(ConversationChunkModel.id, ConversationChunkModel.updated_at) \
                .where(ConversationChunkModel.conversation_id == conversation_id) \
                .order_by(ConversationChunkModel.id.asc())

            conversation_chunk_result = await session.execute(stmt)

            conversation_chunk_list = [
                {
                    "id": row.id,
                    "updated_at": row.updated_at,
                }
                for row in conversation_chunk_result
            ]

            return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return the stored fields of a single conversation chunk.

        Params:
            chunk_id (int, required): chunk to fetch.
            user_id (int, optional): owner to check against; only used when the
                request's ``caller`` is not ``"user"``.

        Responds with the chunk's column values, 404
        ``conversation-chunk-not-found`` if it does not exist, or 403
        ``permission-denied`` if its conversation belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)

            chunk = await session.scalar(stmt)

            if chunk is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # NOTE(review): under an async session this relationship access only
            # works if ``conversation`` is eager-loaded on the model - confirm.
            if chunk.conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # A mapped instance's ``__dict__`` also holds SQLAlchemy's private
            # ``_sa_instance_state`` and may hold the loaded ``conversation``
            # relationship object; neither is chunk data, so leave them out.
            chunk_data = {
                key: value
                for key, value in chunk.__dict__.items()
                if not key.startswith("_") and key != "conversation"
            }

            return await utils.web.api_response(1, chunk_data, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Respond with ``{"tokens": ...}`` as computed by ``TikTokenService.get_tokens``.

        Params:
            question (str, required): text to tokenize.
        """
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()

        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Answer a question about a wiki page over a websocket connection.

        Params:
            title (str, required): page whose index is used.
            question (str, required): the user's question.
            conversation_id (int, optional): conversation to continue.
            extra_limit (int, default 10): passed as ``limit`` to the embedding
                search and to the transaction start call.
            in_collection (bool, default False): passed as ``in_collection`` to
                the embedding search.

        Websocket frames sent to the client:
            * ``{"event": "done", "status": 1, **chat_res}`` on success;
            * ``{"event": "error", "status": -2, ...}`` when the page has no index;
            * ``{"event": "error", "status": -1, ...}`` on any other failure.

        When the caller is ``"user"`` a MediaWiki chat-complete transaction is
        started before the request. It is ended with ``chat_res["total_tokens"]``
        on success and cancelled on failure. Non-websocket requests receive a
        400 ``protocol-mismatch`` response.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")

        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if await chat_complete_service.page_index_exists():
                    tokens = await tiktoken.get_tokens(question)

                    transaction_id = None
                    point_cost = 0
                    if caller == "user":
                        # Open a usage transaction for end users; it is ended or
                        # cancelled below depending on the outcome.
                        usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
                        transaction_id = usage_res.get("transaction_id")
                        point_cost = usage_res.get("point_cost")

                    # NOTE(review): the two callbacks below are not passed to
                    # ``prepare_chat_complete`` or used anywhere else here.
                    async def on_message(text: str):
                        # Send message to client, start with "+" to indicate it's a message
                        # use json will make the package 10x larger
                        await ws.send_str("+" + text)

                    async def on_extracted_doc(doc: list):
                        await ws.send_json({
                            'event': 'extract_doc',
                            'status': 1,
                            'doc': doc
                        })

                    try:
                        chat_res = await chat_complete_service \
                            .prepare_chat_complete(question, conversation_id=conversation_id,
                                                   user_id=user_id, embedding_search={
                                                       "limit": extra_limit,
                                                       "in_collection": in_collection,
                                                   })

                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            **chat_res,
                        })

                        await chat_complete_service.set_latest_point_cost(point_cost)

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(transaction_id, chat_res["total_tokens"])
                    except Exception as e:
                        err_msg = f"Error while processing chat complete request: {e}"

                        traceback.print_exc()

                        try:
                            if not ws.closed:
                                await ws.send_json({
                                    'event': 'error',
                                    'status': -1,
                                    'message': err_msg,
                                    'error': {
                                        'code': 'internal_error',
                                        'title': page_title,
                                    },
                                })
                        finally:
                            # Cancel the transaction even if notifying the client
                            # fails (e.g. the connection has already dropped).
                            if transaction_id:
                                await mwapi.chat_complete_cancel_transaction(transaction_id, error=err_msg)
                else:
                    await ws.send_json({
                        'event': 'error',
                        'status': -2,
                        'message': "Page index not found.",
                        'error': {
                            'code': 'page_not_found',
                            'title': page_title,
                        },
                    })

        # websocket closed
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"

            traceback.print_exc()

            if not ws.closed:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': err_msg,
                    'error': {
                        'code': 'internal_error',
                        'title': page_title,
                    },
                })
        finally:
            if not ws.closed:
                await ws.close()

        # aiohttp requires a websocket handler to return the prepared
        # WebSocketResponse; returning None makes aiohttp raise
        # "Missing return statement on request handler".
        return ws