避免websocket连接关闭时报错

master
落雨楓 2 years ago
parent 7a3ad2afbc
commit 9abc94fe04

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import sys
import time import time
import traceback import traceback
from api.controller.task.ChatCompleteTask import ChatCompleteTask from api.controller.task.ChatCompleteTask import ChatCompleteTask
@ -8,14 +7,10 @@ from api.model.base import clone_model
from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.toolkit_ui.conversation import ConversationHelper from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait from local import noawait
from typing import Optional, Callable, TypedDict
from aiohttp import web from aiohttp import web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteService, ChatCompleteServiceResponse
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
import utils.web import utils.web
@ -531,10 +526,19 @@ class ChatComplete:
}) })
await ws.close() await ws.close()
else: else:
async def on_closed():
task.on_message.remove(on_message)
task.on_finished.remove(on_finished)
task.on_error.remove(on_error)
async def on_message(delta_message: str): async def on_message(delta_message: str):
try:
await ws.send_str("+" + delta_message) await ws.send_str("+" + delta_message)
except ConnectionResetError:
on_closed()
async def on_finished(result: ChatCompleteServiceResponse): async def on_finished(result: ChatCompleteServiceResponse):
try:
ignored_keys = ["message"] ignored_keys = ["message"]
response_result = { response_result = {
"point_cost": task.point_cost, "point_cost": task.point_cost,
@ -549,8 +553,11 @@ class ChatComplete:
}) })
await ws.close() await ws.close()
except ConnectionResetError:
on_closed()
async def on_error(err: Exception): async def on_error(err: Exception):
try:
await ws.send_json({ await ws.send_json({
'event': 'error', 'event': 'error',
'status': -1, 'status': -1,
@ -562,6 +569,8 @@ class ChatComplete:
}) })
await ws.close() await ws.close()
except ConnectionResetError:
on_closed()
task.on_message.append(on_message) task.on_message.append(on_message)
task.on_finished.append(on_finished) task.on_finished.append(on_finished)
@ -574,10 +583,18 @@ class ChatComplete:
'outputed_message': "".join(task.chunks), 'outputed_message': "".join(task.chunks),
}) })
last_heartbeat = time.time()
while True: while True:
if ws.closed: current_time = time.time()
task.on_message.remove(on_message) if ws.closed or task.is_finished:
task.on_finished.remove(on_finished) on_closed()
task.on_error.remove(on_error) break
if current_time - last_heartbeat >= 15:
try:
await ws.ping('{"event":"ping"}'.encode('utf-8'))
last_heartbeat = current_time
except ConnectionResetError:
on_closed()
break break
await asyncio.sleep(0.1) await asyncio.sleep(0.1)

@ -120,6 +120,8 @@ class EmbeddingSearch:
}) })
if transatcion_id: if transatcion_id:
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg) await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
except ConnectionResetError:
pass # Ignore websocket close error
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
print(error_msg, file=sys.stderr) print(error_msg, file=sys.stderr)

@ -443,4 +443,4 @@ class ChatCompleteService:
response = await self.openai_api.chat_complete( response = await self.openai_api.chat_complete(
title_prompt, title_system_prompt title_prompt, title_system_prompt
) )
return response["message"], response["message_tokens"] return response["message"][0:250], response["message_tokens"]

Loading…
Cancel
Save