You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

280 lines
10 KiB
Python

import asyncio
import json
import time
import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web
class ChatCompleteTaskList:
    """Holds per-task state: a message callback slot and a list of text chunks.

    NOTE(review): ``run`` has no body yet, so this looks like a
    work-in-progress scaffold — confirm before building on it.
    """

    def __init__(self, dbs: DatabaseService):
        """Create an empty task.

        Args:
            dbs: Database service handle. It is accepted but not stored or
                used anywhere in this class at the moment.
        """
        # Callback slot; nothing in this class assigns or invokes it yet.
        self.on_message = None
        self.chunks: list[str] = []

    async def run(self):
        """Task body placeholder; currently does nothing."""
        # ``self`` is required here: ``start`` calls this as a bound method,
        # and a zero-argument signature would raise TypeError on that call.
        pass

    @noawait.wrap
    async def start(self):
        """Await ``run``; the coroutine is wrapped by ``noawait.wrap``."""
        await self.run()
class ChatComplete:
    """aiohttp request handlers for the chat-complete endpoints.

    Every public handler is a static coroutine wrapped by
    ``utils.web.token_auth`` and takes the aiohttp request as its only
    argument.
    """

    @staticmethod
    def _resolve_user_id(request: web.Request, params):
        """Return the user id the request acts for.

        When the request's ``caller`` is ``"user"`` the id stored on the
        request under ``"user"`` is used; otherwise the optional ``user_id``
        request parameter is used (which may be ``None``).
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    async def _error_response(request: web.Request, code: str, message: str,
                              http_status: int):
        """Build a failed (-1) API response carrying ``code`` and ``message``."""
        return await utils.web.api_response(-1, error={
            "code": code,
            "message": message
        }, http_status=http_status, request=request)

    @staticmethod
    async def _send_ws_error(ws, status: int, code: str, message: str,
                             page_title):
        """Send an ``error`` event over the websocket ``ws``."""
        await ws.send_json({
            'event': 'error',
            'status': status,
            'message': message,
            'error': {
                'code': code,
                'title': page_title,
            },
        })

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunks (id and ``updated_at``) of one conversation.

        Request parameters:
            user_id (int, optional): only consulted when the caller is not
                ``"user"``.
            conversation_id (int, required): conversation to list.

        Returns:
            The API response: a 404 error if the conversation does not exist,
            a 403 error if it belongs to a different user, otherwise the list
            of ``{"id", "updated_at"}`` dicts ordered by ascending chunk id.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("conversation_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)
            conversation = await session.scalar(stmt)

            if conversation is None:
                return await ChatComplete._error_response(
                    request, "conversation-not-found",
                    "Conversation not found.", 404)

            if conversation.user_id != user_id:
                return await ChatComplete._error_response(
                    request, "permission-denied", "Permission denied.", 403)

            # Two columns are selected, so the rows must be read with
            # execute(): scalars() would yield only the first column (the id)
            # and the updated_at value would be unreachable.
            stmt = (
                select(ConversationChunkModel.id,
                       ConversationChunkModel.updated_at)
                .where(ConversationChunkModel.conversation_id == conversation_id)
                .order_by(ConversationChunkModel.id.asc())
            )
            rows = await session.execute(stmt)

            # Rows are unpacked positionally, in the order the columns were
            # selected above.
            conversation_chunk_list = [
                {"id": chunk_id, "updated_at": updated_at}
                for chunk_id, updated_at in rows
            ]

            return await utils.web.api_response(
                1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return a single conversation chunk by id.

        Request parameters:
            user_id (int, optional): only consulted when the caller is not
                ``"user"``.
            chunk_id (int, required): chunk to fetch.

        Returns:
            The API response: a 404 error if the chunk does not exist, a 403
            error if its conversation belongs to a different user, otherwise
            the chunk's attribute dict.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)
            chunk = await session.scalar(stmt)

            if chunk is None:
                return await ChatComplete._error_response(
                    request, "conversation-chunk-not-found",
                    "Conversation chunk not found.", 404)

            # NOTE(review): presumably the ``conversation`` relationship is
            # loaded eagerly by the model; an implicit lazy load would not
            # work under an async session — TODO confirm.
            if chunk.conversation.user_id != user_id:
                return await ChatComplete._error_response(
                    request, "permission-denied", "Permission denied.", 403)

            # NOTE(review): ``__dict__`` of an ORM instance also carries
            # SQLAlchemy's internal ``_sa_instance_state`` entry — verify that
            # ``api_response`` can serialize or skip it.
            return await utils.web.api_response(
                1, chunk.__dict__, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of the ``question`` request parameter.

        Returns:
            The API response wrapping ``{"tokens": <value from
            TikTokenService.get_tokens>}``.
        """
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Run a chat-complete request over a websocket connection.

        Request parameters:
            title (str, required): page title handed to ChatCompleteService.
            question (str, required): the user's question.
            conversation_id (int, optional): conversation to continue.
            extra_limit (int, optional, default 10): embedding search limit.
            in_collection (bool, optional, default False): embedding search
                flag forwarded unchanged.

        Events sent to the client (JSON):
            ``done``  (status 1)  — merged with the prepare_chat_complete result.
            ``error`` (status -1) — an exception occurred while processing.
            ``error`` (status -2) — the page index does not exist.

        Returns:
            The WebSocketResponse for websocket requests, or a 400
            ``protocol-mismatch`` API response for plain HTTP requests.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if await chat_complete_service.page_index_exists():
                    tokens = await tiktoken.get_tokens(question)

                    transaction_id = None
                    point_cost = 0
                    # A transaction is only opened for callers identified as
                    # "user"; for any other caller point_cost stays 0.
                    if caller == "user":
                        usage_res = await mwapi.chat_complete_start_transaction(
                            user_id, "chatcomplete", tokens, extra_limit)
                        transaction_id = usage_res.get("transaction_id")
                        point_cost = usage_res.get("point_cost")

                    # NOTE(review): the two callbacks below are defined but
                    # not passed to any call in this handler — confirm whether
                    # a streaming step is missing.
                    async def on_message(text: str):
                        # Send message to client, start with "+" to indicate it's a message
                        # use json will make the package 10x larger
                        await ws.send_str("+" + text)

                    async def on_extracted_doc(doc: list):
                        await ws.send_json({
                            'event': 'extract_doc',
                            'status': 1,
                            'doc': doc
                        })

                    try:
                        chat_res = await chat_complete_service \
                            .prepare_chat_complete(question, conversation_id=conversation_id, user_id=user_id, embedding_search={
                                "limit": extra_limit,
                                "in_collection": in_collection,
                            })

                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            **chat_res,
                        })

                        await chat_complete_service.set_latest_point_cost(point_cost)

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(
                                transaction_id, chat_res["total_tokens"])
                    except Exception as e:
                        err_msg = f"Error while processing chat complete request: {e}"
                        traceback.print_exc()

                        if not ws.closed:
                            await ChatComplete._send_ws_error(
                                ws, -1, 'internal_error', err_msg, page_title)

                        # Cancel the open transaction whether or not the
                        # socket is still open.
                        if transaction_id:
                            await mwapi.chat_complete_cancel_transaction(
                                transaction_id, error=err_msg)
                else:
                    await ChatComplete._send_ws_error(
                        ws, -2, 'page_not_found', "Page index not found.",
                        page_title)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            if not ws.closed:
                await ChatComplete._send_ws_error(
                    ws, -1, 'internal_error', err_msg, page_title)
        finally:
            if not ws.closed:
                await ws.close()

        # aiohttp expects a websocket handler to return its response object;
        # falling off the end would return None.
        return ws