from __future__ import annotations

import asyncio
import time
import traceback

from libs.config import Config
from server.controller.task.ChatCompleteTask import ChatCompleteTask
from server.model.base import clone_model
from server.model.chat_complete.bot_persona import BotPersonaHelper
from server.model.toolbox_ui.conversation import ConversationHelper
from utils.local import noawait
from aiohttp import web
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService
import utils.web


class ChatComplete:
    """aiohttp request handlers for the chat-complete API.

    Covers conversation chunks, conversation forking, token/point estimates,
    bot personas, starting a chat-complete task and streaming its output over
    a websocket. Every public handler is wrapped by ``utils.web.token_auth``.
    """

    @staticmethod
    def _resolve_user_id(request: web.Request, params: dict):
        """Return the user id a request acts on behalf of.

        When ``request["caller"] == "user"`` this is the ``user`` value stored
        on the request; for any other caller type it is the ``user_id``
        request parameter (``None`` if it was not supplied).
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunk ids of a conversation owned by the acting user.

        Responds 404 ``conversation-not-found`` if the conversation does not
        exist and 403 ``permission-denied`` if it belongs to another user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("id")

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            conversation_info = await conversation_helper.find_by_id(conversation_id)
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id)
            # Materialize while the helpers are still open; the result may be
            # a lazy iterable (NOTE(review): confirm against the helper).
            conversation_chunk_list = list(conversation_chunk_result)

        return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk, if its conversation belongs to the acting user.

        Responds 404 ``conversation-chunk-not-found`` for an unknown chunk and
        403 ``permission-denied`` when the owning conversation is missing or
        belongs to another user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
            if chunk_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)
            # A chunk whose conversation is gone cannot be attributed to any
            # user, so deny it instead of skipping the ownership check.
            if conversation_info is None or conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            chunk_dict = {
                "id": chunk_info.id,
                "conversation_id": chunk_info.conversation_id,
                "message_data": chunk_info.message_data,
                "tokens": chunk_info.tokens,
                "updated_at": chunk_info.updated_at,
            }

        return await utils.web.api_response(1, chunk_dict, request=request)

    @staticmethod
    @utils.web.token_auth
    async def fork_conversation(request: web.Request):
        """Copy a conversation, together with one of its chunks, into a new conversation.

        ``message_id`` is packed as ``"<chunk_id>,<message_id>"``. With it, the
        given chunk is copied and only messages up to and including that
        message are kept. Without it, the newest chunk is copied in full. A
        ``notice`` message of type ``forked`` is prepended to the copied chunk.
        The new conversation's description becomes the first 150 characters
        of the last assistant message in the copy.

        Error responses:
        - 400 ``invalid-params`` for a malformed ``message_id``.
        - 404 ``conversation-not-found``, ``conversation-chunk-not-found`` or
          ``message-not-found``.
        - 403 ``permission-denied``.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "id": {
                "required": True,
                "type": int,
            },
            "message_id": {
                "required": False,
                "type": str
            },
            "new_title": {
                "required": False,
                "type": str
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id: int = params.get("id")
        packed_message_id: str | None = params.get("message_id")
        new_title = params.get("new_title")

        chunk_id = None
        msg_id = None
        if packed_message_id is not None:
            # Reject malformed ids with a 400 rather than letting the
            # unpacking / int() ValueError escape the handler.
            try:
                raw_chunk_id, msg_id = packed_message_id.split(",")
                chunk_id = int(raw_chunk_id)
            except ValueError:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid message_id. Expected \"<chunk_id>,<message_id>\"."
                }, http_status=400, request=request)

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            conversation_info = await conversation_helper.find_by_id(conversation_id)
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Pick the chunk to copy: the requested one, or the newest one.
            if chunk_id is not None:
                chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
                if chunk_info is None or chunk_info.conversation_id != conversation_id:
                    return await utils.web.api_response(-1, error={
                        "code": "conversation-chunk-not-found",
                        "message": "Conversation chunk not found."
                    }, http_status=404, request=request)
            else:
                chunk_info = await conversation_chunk_helper.get_newest_chunk(conversation_id)

            # Build the copied chunk before writing anything, so an unknown
            # message id fails without leaving a half-created fork behind.
            new_chunk: ConversationChunkModel | None = None
            if chunk_info is not None:
                new_chunk = clone_model(chunk_info)
                messages = new_chunk.message_data

                if msg_id is not None:
                    split_message_pos = next(
                        (i for i, msg_data in enumerate(messages) if msg_data.get("id") == msg_id),
                        None,
                    )
                    if split_message_pos is None:
                        return await utils.web.api_response(-1, error={
                            "code": "message-not-found",
                            "message": "Message not found."
                        }, http_status=404, request=request)
                    # Drop every message after the selected one.
                    messages = messages[0:split_message_pos + 1]

                # Build a fresh list so prepending the notice never mutates
                # a list that might be shared with the source chunk.
                new_chunk.message_data = [
                    {
                        "id": utils.web.generate_uuid(),
                        "role": "notice",
                        "type": "forked",
                        "data": {
                            "original_conversation_id": conversation_info.id,
                            "original_title": conversation_info.title,
                        }
                    },
                    *messages,
                ]

            new_conversation: ConversationModel = clone_model(conversation_info)
            if new_title is not None:
                new_conversation.title = new_title
            new_conversation = await conversation_helper.add(new_conversation)

            if new_chunk is not None:
                new_chunk.conversation_id = new_conversation.id

                # Update the conversation description from the last assistant reply.
                last_assistant_message = next(
                    (msg for msg in reversed(new_chunk.message_data) if msg.get("role") == "assistant"),
                    None,
                )
                if last_assistant_message is not None:
                    new_conversation.description = last_assistant_message["content"][0:150]
                    # update() is assumed to be a coroutine like add()/find_by_id();
                    # without await the description change never runs.
                    await conversation_helper.update(new_conversation)

                await conversation_chunk_helper.add(new_chunk)

        return await utils.web.api_response(1, {
            "conversation_id": new_conversation.id,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of ``question`` as computed by TikTokenService."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_point_usage(request: web.Request):
        """Estimate the points a question would cost with the given bot persona.

        The predicted token count is the question's tokens plus the
        ``estimated_extract_tokens_per_doc`` config value (default 50).
        Responds 404 ``bot-not-found`` for an unknown ``bot_id``.
        """
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True,
            },
            "bot_id": {
                "type": str,
                "required": True,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
        })

        question = params.get("question")
        bot_id = params.get("bot_id")
        # NOTE(review): "extract_limit" is accepted but unused. Given the
        # "per_doc" config name, the estimate may have been meant to scale
        # with it. Confirm before changing the estimate.

        tiktoken = await TikTokenService.create()
        db = await DatabaseService.create(request.app)

        tokens = await tiktoken.get_tokens(question)

        estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
        predict_tokens = tokens + estimated_extract_tokens_per_doc

        async with BotPersonaHelper(db) as bot_persona_helper:
            persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
            if persona_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "bot-not-found",
                    "message": "Bot not found."
                }, http_status=404, request=request)

            point_usage = calculate_point_usage(predict_tokens, persona_info.cost_fixed,
                                                persona_info.cost_fixed_tokens, persona_info.cost_per_token)

        return await utils.web.api_response(1, {
            "point_usage": point_usage,
            "tokens": tokens,
            "predict_tokens": predict_tokens,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_persona_list(request: web.Request):
        """Return one page of bot personas (public fields only) plus the page count."""
        params = await utils.web.get_param(request, {
            "page": {
                "type": int,
                "required": False,
                "default": 1,
            }
        })

        page = params.get("page")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as bot_persona_helper:
            persona_list = await bot_persona_helper.get_list(page=page)
            page_count = await bot_persona_helper.get_page_count()

            persona_data_list = [
                {
                    "id": persona.id,
                    "bot_id": persona.bot_id,
                    "bot_name": persona.bot_name,
                    "bot_avatar": persona.bot_avatar,
                    "bot_description": persona.bot_description,
                    "model_id": persona.model_id,
                    "model_name": persona.model_name,
                    "cost_fixed": persona.cost_fixed
                }
                for persona in persona_list
            ]

        return await utils.web.api_response(1, {
            "list": persona_data_list,
            "page_count": page_count,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_persona_info(request: web.Request):
        """Return every non-underscore attribute of one persona, looked up by ``id`` or ``bot_id``.

        ``id`` takes precedence over ``bot_id``. Responds 400 ``invalid-params``
        if neither is given and 404 ``bot-not-found`` if no persona matches.
        """
        params = await utils.web.get_param(request, {
            "id": {
                "type": int,
            },
            "bot_id": {
                "type": str,
            }
        })

        persona_id = params.get("id")
        bot_id = params.get("bot_id")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as bot_persona_helper:
            if persona_id is not None:
                persona_info = await bot_persona_helper.find_by_id(persona_id)
            elif bot_id is not None:
                persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
            else:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid params. Please specify id or bot_id."
                }, http_status=400, request=request)

            # Previously a missing persona crashed on persona_info.__dict__.
            if persona_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "bot-not-found",
                    "message": "Bot not found."
                }, http_status=404, request=request)

            # Expose the model's public attributes only (skip "_"-prefixed internals).
            persona_info_res = {
                key: value
                for key, value in persona_info.__dict__.items()
                if not key.startswith("_")
            }

        return await utils.web.api_response(1, persona_info_res, request=request)

    @staticmethod
    @utils.web.token_auth
    async def start_chat_complete(request: web.Request):
        """Create a ChatCompleteTask, schedule it without awaiting, and return its ids.

        The response carries ``task_id``, which a client passes to
        ``chat_complete_stream`` to receive the output. Error responses:
        - 404 ``page-not-found``.
        - 403 ``no-enough-points``.
        - 400 ``question-too-long``.
        - 500 ``internal-server-error`` for anything else.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "bot_id": {
                "type": str,
                "required": True,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
            "edit_message_id": {
                "type": str,
                "required": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")

        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        bot_id = params.get("bot_id")

        extract_limit = params.get("extract_limit")
        in_collection = params.get("in_collection")

        edit_message_id = params.get("edit_message_id")

        dbs = await DatabaseService.create(request.app)

        try:
            chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")

            init_res = await chat_complete_task.init(question, conversation_id=conversation_id,
                                                     edit_message_id=edit_message_id, bot_id=bot_id,
                                                     embedding_search={
                                                         "limit": extract_limit,
                                                         "in_collection": in_collection,
                                                     })

            # Run the task in the background; clients follow it via chat_complete_stream.
            noawait.add_task(chat_complete_task.run())

            return await utils.web.api_response(1, data={
                "conversation_id": init_res["conversation_id"],
                "chunk_id": init_res["chunk_id"],
                "question_tokens": init_res["question_tokens"],
                "extract_doc": init_res["extract_doc"],
                "task_id": chat_complete_task.task_id,
            }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-1, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404, request=request)
        except MediaWikiUserNoEnoughPointsException:
            # The old `"..." % user_id` had no placeholder and raised TypeError.
            error_msg = "Does not have enough points."
            return await utils.web.api_response(-1, error={
                "code": "no-enough-points",
                "message": error_msg
            }, http_status=403, request=request)
        except ChatCompleteQuestionTooLongException as e:
            error_msg = "Question too long."
            return await utils.web.api_response(-1, error={
                "code": "question-too-long",
                "limit": e.tokens_limit,
                "current": e.tokens_current,
                "message": error_msg
            }, http_status=400, request=request)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": err_msg
            }, http_status=500, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete_stream(request: web.Request):
        """Stream a chat-complete task's output over a websocket.

        Protocol as implemented here:
        - JSON ``{"event": "connected", "outputed_message": ...}`` carries the
          text produced so far.
        - Each new delta is sent as a text frame prefixed with ``"+"``.
        - The stream ends with a JSON ``finished`` or ``error`` event, after
          which the socket is closed.
        - A ping is sent every 15 seconds while waiting.

        Every path now closes the socket and returns the ``WebSocketResponse``.
        """
        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, error={
                "code": "websocket-required",
                "message": "This API only accept websocket connection."
            }, http_status=400, request=request)

        params = await utils.web.get_param(request, {
            "task_id": {
                "type": str,
                "required": True,
            }
        })

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        task_id = params.get("task_id")
        task = ChatCompleteTask.get_by_id(task_id)
        if task is None:
            await ws.send_json({
                'event': 'error',
                'status': -1,
                'message': "Task not found.",
                'error': {
                    'code': "task-not-found",
                    'info': "Task not found.",
                },
            })
            await ws.close()
            return ws

        if request.get("caller") == "user":
            user_id = request.get("user")
            if task.user_id != user_id:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': "Permission denied.",
                    'error': {
                        'code': "permission-denied",
                        'info': "Permission denied.",
                    },
                })
                await ws.close()
                return ws

        if task.is_finished:
            # Replay the final state of an already-finished task.
            if task.error is not None:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': str(task.error),
                    'error': {
                        'code': "internal-server-error",
                        'info': str(task.error),
                    },
                })
            elif task.result is not None:
                await ws.send_json({
                    'event': 'connected',
                    'status': 1,
                    'outputed_message': "".join(task.chunks),
                })
                await ws.send_json({
                    'event': 'finished',
                    'status': 1,
                    'result': task.result
                })
            await ws.close()
            return ws

        async def on_closed():
            # Idempotent: this can run from a failed send and again from the
            # polling loop, and list.remove() raises on a missing handler.
            for handlers, handler in ((task.on_message, on_message),
                                      (task.on_finished, on_finished),
                                      (task.on_error, on_error)):
                if handler in handlers:
                    handlers.remove(handler)

        async def on_message(delta_message: str):
            try:
                await ws.send_str("+" + delta_message)
            except ConnectionResetError:
                await on_closed()

        async def on_finished(result: ChatCompleteServiceResponse):
            try:
                ignored_keys = ["message"]
                response_result = {
                    "point_usage": task.point_usage,
                }
                for k, v in result.items():
                    if k not in ignored_keys:
                        response_result[k] = v
                await ws.send_json({
                    'event': 'finished',
                    'status': 1,
                    'result': response_result
                })
                await ws.close()
            except ConnectionResetError:
                await on_closed()

        async def on_error(err: Exception):
            try:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': str(err),
                    'error': {
                        'code': "internal-server-error",
                        'info': str(err),
                    },
                })
                await ws.close()
            except ConnectionResetError:
                await on_closed()

        task.on_message.append(on_message)
        task.on_finished.append(on_finished)
        task.on_error.append(on_error)

        # Send the output produced before this client connected.
        await ws.send_json({
            'event': 'connected',
            'status': 1,
            'outputed_message': "".join(task.chunks),
        })

        # Poll until the socket closes or the task finishes, pinging every 15 s.
        # NOTE(review): if task.is_finished becomes True before the task invokes
        # on_finished/on_error, the handlers are detached here first and the
        # client gets no final event. Confirm ChatCompleteTask's ordering.
        last_heartbeat = time.time()
        while True:
            current_time = time.time()
            if ws.closed or task.is_finished:
                await on_closed()
                break

            if current_time - last_heartbeat >= 15:
                try:
                    await ws.ping('{"event":"ping"}'.encode('utf-8'))
                    last_heartbeat = current_time
                except ConnectionResetError:
                    await on_closed()
                    break

            await asyncio.sleep(0.1)

        return ws