|
|
|
@ -1,6 +1,5 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
import asyncio
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
import traceback
|
|
|
|
|
from api.controller.task.ChatCompleteTask import ChatCompleteTask
|
|
|
|
@ -8,14 +7,10 @@ from api.model.base import clone_model
|
|
|
|
|
from api.model.chat_complete.bot_persona import BotPersonaHelper
|
|
|
|
|
from api.model.toolkit_ui.conversation import ConversationHelper
|
|
|
|
|
from local import noawait
|
|
|
|
|
from typing import Optional, Callable, TypedDict
|
|
|
|
|
from aiohttp import web
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
|
|
|
|
|
from noawait import NoAwaitPool
|
|
|
|
|
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteService, ChatCompleteServiceResponse
|
|
|
|
|
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
from service.embedding_search import EmbeddingSearchArgs
|
|
|
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
|
|
|
|
|
from service.tiktoken import TikTokenService
|
|
|
|
|
import utils.web
|
|
|
|
@ -531,10 +526,19 @@ class ChatComplete:
|
|
|
|
|
})
|
|
|
|
|
await ws.close()
|
|
|
|
|
else:
|
|
|
|
|
def on_closed():
    """Detach this connection's callbacks from the task.

    Removes ``on_message``, ``on_finished`` and ``on_error`` from
    ``task.on_message``, ``task.on_finished`` and ``task.on_error``.

    This is a plain function rather than a coroutine function because the
    call sites invoke it as a bare ``on_closed()`` without ``await``; as a
    coroutine function that call would only create a coroutine object and
    the handlers would never be removed.

    Safe to call more than once: a handler that is no longer registered is
    skipped, so a cleanup triggered by a failed send followed by a cleanup
    from the polling loop does not raise.
    """
    registrations = (
        (task.on_message, on_message),
        (task.on_finished, on_finished),
        (task.on_error, on_error),
    )
    for handlers, handler in registrations:
        try:
            handlers.remove(handler)
        except ValueError:
            # Already detached by an earlier call; nothing left to do.
            pass
|
|
|
|
|
|
|
|
|
|
async def on_message(delta_message: str):
    """Forward one streamed message delta to the websocket client.

    The delta is sent as a text frame prefixed with ``"+"``.

    Args:
        delta_message: The newly produced piece of the reply text.

    If the send raises ``ConnectionResetError`` the error is not
    propagated; ``on_closed()`` is run instead.
    """
    try:
        await ws.send_str("+" + delta_message)
    except ConnectionResetError:
        # Calling a coroutine function only builds a coroutine object; it
        # has to be awaited for its body to execute. Await the result when
        # it is a coroutine so the cleanup really runs, and accept a plain
        # function (which has already done its work) as well.
        cleanup = on_closed()
        if asyncio.iscoroutine(cleanup):
            await cleanup
|
|
|
|
|
|
|
|
|
|
async def on_finished(result: ChatCompleteServiceResponse):
|
|
|
|
|
try:
|
|
|
|
|
ignored_keys = ["message"]
|
|
|
|
|
response_result = {
|
|
|
|
|
"point_cost": task.point_cost,
|
|
|
|
@ -549,8 +553,11 @@ class ChatComplete:
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
except ConnectionResetError:
|
|
|
|
|
on_closed()
|
|
|
|
|
|
|
|
|
|
async def on_error(err: Exception):
|
|
|
|
|
try:
|
|
|
|
|
await ws.send_json({
|
|
|
|
|
'event': 'error',
|
|
|
|
|
'status': -1,
|
|
|
|
@ -562,6 +569,8 @@ class ChatComplete:
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
await ws.close()
|
|
|
|
|
except ConnectionResetError:
|
|
|
|
|
on_closed()
|
|
|
|
|
|
|
|
|
|
task.on_message.append(on_message)
|
|
|
|
|
task.on_finished.append(on_finished)
|
|
|
|
@ -574,10 +583,18 @@ class ChatComplete:
|
|
|
|
|
'outputed_message': "".join(task.chunks),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
last_heartbeat = time.time()
|
|
|
|
|
while True:
|
|
|
|
|
if ws.closed:
|
|
|
|
|
task.on_message.remove(on_message)
|
|
|
|
|
task.on_finished.remove(on_finished)
|
|
|
|
|
task.on_error.remove(on_error)
|
|
|
|
|
current_time = time.time()
|
|
|
|
|
if ws.closed or task.is_finished:
|
|
|
|
|
on_closed()
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if current_time - last_heartbeat >= 15:
|
|
|
|
|
try:
|
|
|
|
|
await ws.ping('{"event":"ping"}'.encode('utf-8'))
|
|
|
|
|
last_heartbeat = current_time
|
|
|
|
|
except ConnectionResetError:
|
|
|
|
|
on_closed()
|
|
|
|
|
break
|
|
|
|
|
await asyncio.sleep(0.1)
|
|
|
|
|