更新ChatComplete的API模式

master
落雨楓 2 years ago
parent 2f68357c1d
commit 6a74050b56

@ -1,26 +1,103 @@
from __future__ import annotations
import asyncio import asyncio
import json import json
import sys
import time import time
import traceback import traceback
from local import noawait from local import noawait
from typing import Optional from typing import Optional, Callable, TypedDict
from aiohttp import WSMsgType, web from aiohttp import web
from sqlalchemy import select from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService from service.chat_complete import ChatCompleteService
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
import utils.web import utils.web
class ChatCompleteTaskList: chat_complete_tasks: dict[str, ChatCompleteTask] = {}
def __init__(self, dbs: DatabaseService):
self.on_message = None class ChatCompleteTask:
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_error: list[Callable] = []
self.chunks: list[str] = [] self.chunks: list[str] = []
async def run(): self.chat_complete_service: ChatCompleteService
pass self.chat_complete: ChatCompleteService
self.dbs = dbs
self.user_id = user_id
self.page_title = page_title
self.is_system = is_system
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
async def init(self, question: str, conversation_id: Optional[str] = None,
               embedding_search: Optional[EmbeddingSearchArgs] = None):
    """Open the chat-complete service for the page and prepare the conversation.

    Creates the tokenizer and MediaWiki API helpers, enters the
    ``ChatCompleteService`` context and, for non-system callers, starts a
    usage transaction before delegating to ``prepare_chat_complete``.

    Args:
        question: The user's question; its token count is passed to the
            usage transaction.
        conversation_id: Forwarded unchanged to ``prepare_chat_complete``.
        embedding_search: Optional search arguments, forwarded unchanged.
            Its ``"limit"`` entry is used as the extract limit for the
            usage transaction; 10 is used when the argument is ``None`` or
            the entry is missing/falsy.

    Returns:
        Whatever ``prepare_chat_complete`` returns.

    Raises:
        MediaWikiPageNotFoundException: ``page_index_exists()`` was falsy.
        Exception: Any other failure is re-raised after cleanup.
    """
    self.tiktoken = await TikTokenService.create()
    self.mwapi = MediaWikiApi.create()

    self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
    self.chat_complete = await self.chat_complete_service.__aenter__()

    # Reset per-run billing state so a repeated init() never reuses a
    # transaction from an earlier attempt.
    self.transatcion_id = None
    self.point_cost = 0

    try:
        if not await self.chat_complete.page_index_exists():
            raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)

        question_tokens = await self.tiktoken.get_tokens(question)
        # embedding_search is optional: fall back to the default limit
        # instead of subscripting None.
        extract_limit = (embedding_search or {}).get("limit") or 10

        if not self.is_system:
            usage_res = await self.mwapi.chat_complete_start_transaction(
                self.user_id, "chatcomplete", question_tokens, extract_limit)
            self.transatcion_id = usage_res.get("transaction_id")
            self.point_cost = usage_res.get("point_cost")

        return await self.chat_complete.prepare_chat_complete(
            question, conversation_id=conversation_id,
            user_id=self.user_id, embedding_search=embedding_search)
    except Exception as e:
        # From here on the service context is open and a transaction may
        # have been started; release both before propagating the error.
        try:
            if self.transatcion_id:
                err_msg = f"Error while processing chat complete request: {e}"
                await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
        finally:
            try:
                await self._exit()
            except KeyError:
                # The task may not have been added to the task registry
                # yet; that must not hide the original error.
                pass
        raise
async def _on_message(self, delta_message: str):
    """Forward one streamed message delta to every registered listener.

    Listeners in ``self.on_message`` are awaited one at a time, in list
    order.

    Args:
        delta_message: The text fragment to deliver.
    """
    # Iterate over a snapshot: a listener (or another coroutine running
    # while we await) may add or remove callbacks, and mutating a list
    # during iteration silently skips entries.
    for callback in tuple(self.on_message):
        await callback(delta_message)
async def _on_error(self, err: Exception):
    """Report a failure to every registered error listener.

    Listeners in ``self.on_error`` are awaited one at a time, in list
    order.

    Args:
        err: The exception to deliver.
    """
    # Iterate over a snapshot so listeners that unregister themselves
    # while being notified do not cause later listeners to be skipped.
    for callback in tuple(self.on_error):
        await callback(err)
async def run(self):
    """Run the prepared chat completion to the end and settle the transaction.

    Streams deltas to the message listeners via ``self._on_message``,
    stores the point cost, and ends the usage transaction with the
    reported ``total_tokens``. On failure the error is logged to stderr,
    the transaction (if any) is cancelled and the error listeners are
    notified. ``self._exit()`` is always called at the end.
    """
    try:
        # finish_chat_complete must be awaited: its result is subscripted
        # below, which is impossible on an un-awaited coroutine.
        chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

        await self.chat_complete.set_latest_point_cost(self.point_cost)

        if self.transatcion_id:
            await self.mwapi.chat_complete_end_transaction(self.transatcion_id, chat_res["total_tokens"])
    except Exception as e:
        err_msg = f"Error while processing chat complete request: {e}"

        print(err_msg, file=sys.stderr)
        traceback.print_exc()

        try:
            if self.transatcion_id:
                await self.mwapi.chat_complete_cancel_transaction(self.transatcion_id, error=err_msg)
        finally:
            # Listeners must hear about the failure even if cancelling
            # the transaction itself fails.
            await self._on_error(e)
    finally:
        await self._exit()
async def _exit(self):
    """Close the chat-complete service context and deregister this task.

    Safe to call whether or not the task is present in
    ``chat_complete_tasks``.
    """
    try:
        await self.chat_complete_service.__aexit__(None, None, None)
    finally:
        # pop() instead of del: init() calls this on its failure path,
        # before the request handler has added the task to the registry,
        # and a KeyError here would replace the real error. The finally
        # also keeps a failing __aexit__ from leaving a stale entry.
        chat_complete_tasks.pop(self.task_id, None)
@noawait.wrap @noawait.wrap
async def start(self): async def start(self):
@ -144,7 +221,7 @@ class ChatComplete:
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth
async def chat_complete(request: web.Request): async def start_chat_complete(request: web.Request):
params = await utils.web.get_param(request, { params = await utils.web.get_param(request, {
"title": { "title": {
"type": str, "type": str,
@ -158,7 +235,7 @@ class ChatComplete:
"type": int, "type": int,
"required": False, "required": False,
}, },
"extra_limit": { "extract_limit": {
"type": int, "type": int,
"required": False, "required": False,
"default": 10, "default": 10,
@ -177,103 +254,44 @@ class ChatComplete:
question = params.get("question") question = params.get("question")
conversation_id = params.get("conversation_id") conversation_id = params.get("conversation_id")
extra_limit = params.get("extra_limit") extract_limit = params.get("extract_limit")
in_collection = params.get("in_collection") in_collection = params.get("in_collection")
dbs = await DatabaseService.create(request.app) dbs = await DatabaseService.create(request.app)
tiktoken = await TikTokenService.create()
mwapi = MediaWikiApi.create() try:
if utils.web.is_websocket(request): chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
ws = web.WebSocketResponse() init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
await ws.prepare(request) "limit": extract_limit,
"in_collection": in_collection,
try: })
async with ChatCompleteService(dbs, page_title) as chat_complete_service:
if await chat_complete_service.page_index_exists(): chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
tokens = await tiktoken.get_tokens(question)
chat_complete_task.start()
transatcion_id = None
point_cost = 0 return utils.web.api_response(1, data={
if request.get("caller") == "user": "question_tokens": init_res["question_tokens"],
usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit) "extract_doc": init_res["extract_doc"],
transatcion_id = usage_res.get("transaction_id") "task_id": chat_complete_task.task_id,
point_cost = usage_res.get("point_cost") }, request=request)
except MediaWikiPageNotFoundException as e:
async def on_message(text: str): error_msg = "Page \"%s\" not found." % page_title
# Send message to client, start with "+" to indicate it's a message return await utils.web.api_response(-1, error={
# use json will make the package 10x larger "code": "page-not-found",
await ws.send_str("+" + text) "title": page_title,
"message": error_msg
async def on_extracted_doc(doc: list): }, http_status=404, request=request)
await ws.send_json({ except Exception as e:
'event': 'extract_doc', err_msg = f"Error while processing chat complete request: {e}"
'status': 1, traceback.print_exc()
'doc': doc
}) return await utils.web.api_response(-1, error={
"code": "chat-complete-error",
try: "message": err_msg
chat_res = await chat_complete_service \ }, http_status=500, request=request)
.prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={
"limit": extra_limit, @staticmethod
"in_collection": in_collection, @utils.web.token_auth
}) async def chat_complete_stream(request: web.Request):
await ws.send_json({ pass
'event': 'done',
'status': 1,
**chat_res,
})
await chat_complete_service.set_latest_point_cost(point_cost)
if transatcion_id:
result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"])
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
if not ws.closed:
await ws.send_json({
'event': 'error',
'status': -1,
'message': err_msg,
'error': {
'code': 'internal_error',
'title': page_title,
},
})
if transatcion_id:
result = await mwapi.chat_complete_cancel_transaction(transatcion_id, error=err_msg)
else:
await ws.send_json({
'event': 'error',
'status': -2,
'message': "Page index not found.",
'error': {
'code': 'page_not_found',
'title': page_title,
},
})
# websocket closed
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
if not ws.closed:
await ws.send_json({
'event': 'error',
'status': -1,
'message': err_msg,
'error': {
'code': 'internal_error',
'title': page_title,
},
})
finally:
if not ws.closed:
await ws.close()
else:
return await utils.web.api_response(-1, request=request, error={
"code": "protocol-mismatch",
"message": "Protocol mismatch, websocket request expected."
}, http_status=400)

@ -29,5 +29,6 @@ def init(app: web.Application):
web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk), web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk),
web.route('*', '/chatcomplete/message', ChatComplete.chat_complete), web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
]) ])

@ -83,7 +83,7 @@ class ChatCompleteService:
async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
question_tokens: Optional[int] = None, question_tokens: Optional[int] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse: embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
if user_id is not None: if user_id is not None:
user_id = int(user_id) user_id = int(user_id)

@ -4,6 +4,7 @@ from typing import Any, Optional, Dict
from aiohttp import web from aiohttp import web
import jwt import jwt
import config import config
import uuid
ParamRule = Dict[str, Any] ParamRule = Dict[str, Any]
@ -86,7 +87,10 @@ async def api_response(status, data=None, error=None, warning=None, http_status=
return web.json_response(ret, status=http_status) return web.json_response(ret, status=http_status)
def is_websocket(request: web.Request): def is_websocket(request: web.Request):
return request.headers.get('Upgrade', '').lower() == 'websocket' return request.headers.get('Upgrade', '').lower() == 'websocket'
def generate_uuid():
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
# Auth decorator # Auth decorator
def token_auth(f): def token_auth(f):

Loading…
Cancel
Save