diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index 134406b..b4b817f 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -1,6 +1,5 @@ from __future__ import annotations import asyncio -import sys import time import traceback from api.controller.task.ChatCompleteTask import ChatCompleteTask @@ -8,14 +7,10 @@ from api.model.base import clone_model from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.toolkit_ui.conversation import ConversationHelper from local import noawait -from typing import Optional, Callable, TypedDict from aiohttp import web -from sqlalchemy import select from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel -from noawait import NoAwaitPool -from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteService, ChatCompleteServiceResponse +from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse from service.database import DatabaseService -from service.embedding_search import EmbeddingSearchArgs from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException from service.tiktoken import TikTokenService import utils.web @@ -531,37 +526,51 @@ class ChatComplete: }) await ws.close() else: + async def on_closed(): + task.on_message.remove(on_message) + task.on_finished.remove(on_finished) + task.on_error.remove(on_error) + async def on_message(delta_message: str): - await ws.send_str("+" + delta_message) + try: + await ws.send_str("+" + delta_message) + except ConnectionResetError: + on_closed() async def on_finished(result: ChatCompleteServiceResponse): - ignored_keys = ["message"] - response_result = { - "point_cost": task.point_cost, - } - for k, v in result.items(): - if k not in ignored_keys: - response_result[k] = v - await ws.send_json({ - 'event': 'finished', - 'status': 1, - 'result': response_result - }) - - await ws.close() + try: + ignored_keys = ["message"] + response_result = { + "point_cost": task.point_cost, + } + for k, v in result.items(): + if k not in ignored_keys: + response_result[k] = v + await ws.send_json({ + 'event': 'finished', + 'status': 1, + 'result': response_result + }) + + await ws.close() + except ConnectionResetError: + on_closed() async def on_error(err: Exception): - await ws.send_json({ - 'event': 'error', - 'status': -1, - 'message': str(err), - 'error': { - 'code': "internal-server-error", - 'info': str(err), - }, - }) - - await ws.close() + try: + await ws.send_json({ + 'event': 'error', + 'status': -1, + 'message': str(err), + 'error': { + 'code': "internal-server-error", + 'info': str(err), + }, + }) + + await ws.close() + except ConnectionResetError: + on_closed() task.on_message.append(on_message) task.on_finished.append(on_finished) @@ -574,10 +583,18 @@ class ChatComplete: 'outputed_message': "".join(task.chunks), }) + last_heartbeat = time.time() while True: - if ws.closed: - task.on_message.remove(on_message) - task.on_finished.remove(on_finished) - task.on_error.remove(on_error) + current_time = time.time() + if ws.closed or task.is_finished: + on_closed() break + + if current_time - last_heartbeat >= 15: + try: + await ws.ping('{"event":"ping"}'.encode('utf-8')) + last_heartbeat = current_time + except ConnectionResetError: + on_closed() + break await asyncio.sleep(0.1) diff --git a/api/controller/EmbeddingSearch.py b/api/controller/EmbeddingSearch.py index 0545246..ee9db3f 100644 --- a/api/controller/EmbeddingSearch.py +++ b/api/controller/EmbeddingSearch.py @@ -120,6 +120,8 @@ class EmbeddingSearch: }) if transatcion_id: await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg) + except ConnectionResetError: + pass # Ignore websocket close error except Exception as e: error_msg = str(e) print(error_msg, file=sys.stderr) diff --git a/service/chat_complete.py b/service/chat_complete.py index 39b93f1..ea81ba8 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -443,4 +443,4 @@ class ChatCompleteService: response = await self.openai_api.chat_complete( title_prompt, title_system_prompt ) - return response["message"], response["message_tokens"] + return response["message"][0:250], response["message_tokens"]