更新ChatComplete的API模式

master
落雨楓 2 years ago
parent 2f68357c1d
commit 6a74050b56

@ -1,26 +1,103 @@
from __future__ import annotations
import asyncio
import json
import sys
import time
import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web
from typing import Optional, Callable, TypedDict
from aiohttp import web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web
class ChatCompleteTaskList:
def __init__(self, dbs: DatabaseService):
self.on_message = None
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_error: list[Callable] = []
self.chunks: list[str] = []
async def run():
pass
self.chat_complete_service: ChatCompleteService
self.chat_complete: ChatCompleteService
self.dbs = dbs
self.user_id = user_id
self.page_title = page_title
self.is_system = is_system
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
async def init(self, question: str, conversation_id: Optional[str] = None,
               embedding_search: Optional[EmbeddingSearchArgs] = None):
    """Open the chat-complete service, bill the request and prepare the answer.

    Steps, in order: create the tokenizer and MediaWiki API helpers, enter the
    ChatCompleteService context for ``self.page_title``, verify that a page
    index exists, count the question tokens, start a usage transaction (skipped
    for system callers) and finally call ``prepare_chat_complete``.

    Args:
        question: The user's question.
        conversation_id: Conversation to continue, if any.
        embedding_search: Optional search arguments; only its ``"limit"`` key
            is read here (defaults to 10 when missing or falsy), the whole
            mapping is forwarded unchanged to ``prepare_chat_complete``.

    Returns:
        Whatever ``prepare_chat_complete`` returns.

    Raises:
        MediaWikiPageNotFoundException: no page index exists for the page.
        Exception: any error from the underlying services is re-raised after
            cleanup.
    """
    self.tiktoken = await TikTokenService.create()
    self.mwapi = MediaWikiApi.create()

    self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
    self.chat_complete = await self.chat_complete_service.__aenter__()

    # Reset billing state so a failed init never reports a stale transaction.
    self.transatcion_id = None
    self.point_cost = 0

    try:
        if not await self.chat_complete.page_index_exists():
            raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)

        question_tokens = await self.tiktoken.get_tokens(question)
        # embedding_search is optional: do not subscript it when it is None.
        extract_limit = (embedding_search or {}).get("limit") or 10

        if not self.is_system:
            usage_res = await self.mwapi.chat_complete_start_transaction(
                self.user_id, "chatcomplete", question_tokens, extract_limit)
            self.transatcion_id = usage_res.get("transaction_id")
            self.point_cost = usage_res.get("point_cost")

        return await self.chat_complete.prepare_chat_complete(
            question, conversation_id=conversation_id,
            user_id=self.user_id, embedding_search=embedding_search)
    except Exception as e:
        # Nothing will call run() after a failed init, so undo everything here.
        try:
            if self.transatcion_id:
                await self.mwapi.chat_complete_cancel_transaction(
                    self.transatcion_id,
                    error=f"Error while processing chat complete request: {e}")
        finally:
            # Close the service context directly: the task has not been added
            # to the task registry yet, so there is nothing to unregister.
            await self.chat_complete_service.__aexit__(None, None, None)
        raise
async def _on_message(self, delta_message: str):
    """Forward one message delta to every callback in ``self.on_message``.

    Callbacks are awaited one after another, in list order.
    """
    listeners = self.on_message
    for listener in listeners:
        await listener(delta_message)
async def _on_error(self, err: Exception):
    """Report ``err`` to every callback in ``self.on_error``.

    Callbacks are awaited one after another, in list order.
    """
    handlers = self.on_error
    for handler in handlers:
        await handler(err)
async def run(self):
    """Finish the prepared chat completion and settle the usage transaction.

    ``self._on_message`` is handed to ``finish_chat_complete`` as the message
    callback. On success the point cost is stored and, when a transaction was
    started, it is ended with the reported ``total_tokens``. On failure the
    error is logged to stderr, the transaction (if any) is cancelled and the
    ``on_error`` callbacks are notified. ``self._exit()`` always runs last.

    Exceptions raised by the completion itself are not propagated to the
    caller; they are delivered through ``self._on_error``.
    """
    try:
        # The result is subscripted below, so the call must be awaited.
        # NOTE(review): assumes finish_chat_complete is a coroutine function
        # (it receives an async callback) - confirm against the service.
        chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

        await self.chat_complete.set_latest_point_cost(self.point_cost)

        if self.transatcion_id:
            await self.mwapi.chat_complete_end_transaction(
                self.transatcion_id, chat_res["total_tokens"])
    except Exception as e:
        err_msg = f"Error while processing chat complete request: {e}"

        print(err_msg, file=sys.stderr)
        traceback.print_exc()

        if self.transatcion_id:
            await self.mwapi.chat_complete_cancel_transaction(
                self.transatcion_id, error=err_msg)

        await self._on_error(e)
    finally:
        await self._exit()
async def _exit(self):
    """Leave the chat-complete service context and unregister this task.

    Safe to call for a task that was never added to ``chat_complete_tasks``
    (for example when ``init`` fails before the caller registers it): the
    registry entry is removed only if present. The entry is removed even when
    closing the service raises, so a failed close cannot leave a dead task in
    the registry.
    """
    try:
        await self.chat_complete_service.__aexit__(None, None, None)
    finally:
        # pop() with a default instead of del: no KeyError for unknown ids.
        chat_complete_tasks.pop(self.task_id, None)
@noawait.wrap
async def start(self):
@ -144,7 +221,7 @@ class ChatComplete:
@staticmethod
@utils.web.token_auth
async def chat_complete(request: web.Request):
async def start_chat_complete(request: web.Request):
params = await utils.web.get_param(request, {
"title": {
"type": str,
@ -158,7 +235,7 @@ class ChatComplete:
"type": int,
"required": False,
},
"extra_limit": {
"extract_limit": {
"type": int,
"required": False,
"default": 10,
@ -177,103 +254,44 @@ class ChatComplete:
question = params.get("question")
conversation_id = params.get("conversation_id")
extra_limit = params.get("extra_limit")
extract_limit = params.get("extract_limit")
in_collection = params.get("in_collection")
dbs = await DatabaseService.create(request.app)
tiktoken = await TikTokenService.create()
mwapi = MediaWikiApi.create()
if utils.web.is_websocket(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
try:
async with ChatCompleteService(dbs, page_title) as chat_complete_service:
if await chat_complete_service.page_index_exists():
tokens = await tiktoken.get_tokens(question)
transatcion_id = None
point_cost = 0
if request.get("caller") == "user":
usage_res = await mwapi.chat_complete_start_transaction(user_id, "chatcomplete", tokens, extra_limit)
transatcion_id = usage_res.get("transaction_id")
point_cost = usage_res.get("point_cost")
async def on_message(text: str):
# Send message to client, start with "+" to indicate it's a message
# use json will make the package 10x larger
await ws.send_str("+" + text)
async def on_extracted_doc(doc: list):
await ws.send_json({
'event': 'extract_doc',
'status': 1,
'doc': doc
})
try:
chat_res = await chat_complete_service \
.prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={
"limit": extra_limit,
"in_collection": in_collection,
})
await ws.send_json({
'event': 'done',
'status': 1,
**chat_res,
})
await chat_complete_service.set_latest_point_cost(point_cost)
if transatcion_id:
result = await mwapi.chat_complete_end_transaction(transatcion_id, chat_res["total_tokens"])
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
if not ws.closed:
await ws.send_json({
'event': 'error',
'status': -1,
'message': err_msg,
'error': {
'code': 'internal_error',
'title': page_title,
},
})
if transatcion_id:
result = await mwapi.chat_complete_cancel_transaction(transatcion_id, error=err_msg)
else:
await ws.send_json({
'event': 'error',
'status': -2,
'message': "Page index not found.",
'error': {
'code': 'page_not_found',
'title': page_title,
},
})
# websocket closed
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
if not ws.closed:
await ws.send_json({
'event': 'error',
'status': -1,
'message': err_msg,
'error': {
'code': 'internal_error',
'title': page_title,
},
})
finally:
if not ws.closed:
await ws.close()
else:
return await utils.web.api_response(-1, request=request, error={
"code": "protocol-mismatch",
"message": "Protocol mismatch, websocket request expected."
}, http_status=400)
try:
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
"limit": extract_limit,
"in_collection": in_collection,
})
chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task
chat_complete_task.start()
return utils.web.api_response(1, data={
"question_tokens": init_res["question_tokens"],
"extract_doc": init_res["extract_doc"],
"task_id": chat_complete_task.task_id,
}, request=request)
except MediaWikiPageNotFoundException as e:
error_msg = "Page \"%s\" not found." % page_title
return await utils.web.api_response(-1, error={
"code": "page-not-found",
"title": page_title,
"message": error_msg
}, http_status=404, request=request)
except Exception as e:
err_msg = f"Error while processing chat complete request: {e}"
traceback.print_exc()
return await utils.web.api_response(-1, error={
"code": "chat-complete-error",
"message": err_msg
}, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def chat_complete_stream(request: web.Request):
pass

@ -29,5 +29,6 @@ def init(app: web.Application):
web.route('*', '/chatcomplete/conversation_chunks', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/{id:^\d+}', ChatComplete.get_conversation_chunk),
web.route('*', '/chatcomplete/message', ChatComplete.chat_complete),
web.route('*', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('*', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
])

@ -83,7 +83,7 @@ class ChatCompleteService:
async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse:
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
if user_id is not None:
user_id = int(user_id)

@ -4,6 +4,7 @@ from typing import Any, Optional, Dict
from aiohttp import web
import jwt
import config
import uuid
ParamRule = Dict[str, Any]
@ -88,6 +89,9 @@ async def api_response(status, data=None, error=None, warning=None, http_status=
def is_websocket(request: web.Request):
    """Return True when the request's ``Upgrade`` header is ``websocket``.

    The comparison is case-insensitive; a missing header counts as no upgrade.
    """
    upgrade_header = request.headers.get('Upgrade', '')
    return upgrade_header.lower() == 'websocket'
def generate_uuid():
    """Return a newly generated random (version 4) UUID as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
# Auth decorator
def token_auth(f):
@wraps(f)

Loading…
Cancel
Save