You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
327 lines
12 KiB
Python
327 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
import time
|
|
import traceback
|
|
from aiohttp import WSMsgType, web
|
|
from sqlalchemy import select
|
|
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
|
|
from service.chat_complete import ChatCompleteService
|
|
from service.database import DatabaseService
|
|
from service.mediawiki_api import MediaWikiApi
|
|
from service.tiktoken import TikTokenService
|
|
import utils.web
|
|
|
|
|
|
class ChatCompleteWebSocketController:
    """Long-lived websocket session for chat-complete requests.

    The client sends JSON text frames of the form ``{"event": ..., ...}``:

    * ``ping``         -> answered with ``{"event": "pong"}``
    * ``chatcomplete`` -> dispatched to :meth:`_chatcomplete` in the background

    The socket is closed by the server when no text frame has been received
    for ``IDLE_TIMEOUT`` seconds.

    NOTE(review): ``_chatcomplete`` is currently a stub that does not reply.
    """

    # Seconds without any client text frame before the server closes the socket.
    IDLE_TIMEOUT = 30

    def __init__(self, request: web.Request):
        self.request = request
        self.ws = None
        self.db = None
        self.chat_complete = None
        self.query = None
        # Acting user / page title resolved in run(); kept for _chatcomplete.
        self.user_id = None
        self.title = None

        # Set once the socket is (being) closed; stops the heartbeat loop.
        self.closed = False
        # time.time() of the last text frame received from the client.
        self.refreshed_time = 0
        # Strong references to background tasks: asyncio only keeps weak
        # references, so unreferenced tasks may vanish mid-execution.
        self._tasks = set()

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def run(self):
        """Accept the websocket and process client frames until it closes.

        Returns:
            The prepared ``web.WebSocketResponse`` (aiohttp handlers must
            return their response object).
        """
        self.ws = web.WebSocketResponse()
        await self.ws.prepare(self.request)
        self.refreshed_time = time.time()

        self.db = await DatabaseService.create(self.request.app)

        self.query = self.request.query
        # Users act on their own account; other callers name the user.
        if self.request.get("caller") == "user":
            self.user_id = self.request.get("user")
        else:
            self.user_id = self.query.get("user_id")
        self.title = self.query.get("title")

        heartbeat = self._spawn(self._timeout_task())
        try:
            async for msg in self.ws:
                if msg.type == WSMsgType.TEXT:
                    await self._handle_text(msg.data)
                elif msg.type == WSMsgType.ERROR:
                    print('ws connection closed with exception %s' %
                          self.ws.exception())
        finally:
            # The client went away (or we closed it): stop the idle watchdog
            # immediately instead of letting it poll for up to IDLE_TIMEOUT.
            self.closed = True
            heartbeat.cancel()

        return self.ws

    async def _handle_text(self, raw: str):
        """Decode one JSON text frame and dispatch on its ``event`` field."""
        try:
            data = json.loads(raw)
            event = data.get('event')
            # Any well-formed frame counts as activity for the idle timeout.
            self.refreshed_time = time.time()

            if event == 'chatcomplete':
                self._spawn(self._chatcomplete(data))
            elif event == 'ping':
                await self.ws.send_json({
                    'event': 'pong'
                })
        except Exception as e:
            print(e)
            traceback.print_exc()
            if not self.ws.closed:
                await self.ws.send_json({
                    'event': 'error',
                    'error': str(e)
                })

    async def _timeout_task(self):
        """Close the socket once it has been idle for ``IDLE_TIMEOUT`` seconds."""
        while not self.closed:
            if time.time() - self.refreshed_time > self.IDLE_TIMEOUT:
                self.closed = True
                await self.ws.close()
                return

            await asyncio.sleep(1)

    async def _chatcomplete(self, params: dict):
        """Handle a ``chatcomplete`` event.

        TODO(review): not implemented yet — the request fields are read but
        nothing is sent back to the client.
        """
        question = params.get("question")
        conversation_id = params.get("conversation_id")
|
|
|
|
|
|
class ChatComplete:
    """HTTP / websocket request handlers for the chat-complete API."""

    @staticmethod
    def _resolve_user_id(request: web.Request, params: dict):
        """Return the user id a request acts on.

        Requests whose ``caller`` is ``"user"`` act on their own account
        (``request["user"]``); any other caller must pass ``user_id``
        explicitly. NOTE(review): ``caller``/``user`` are presumably set by
        ``utils.web.token_auth`` — confirm.
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    async def _ws_send_error(ws, status: int, message: str, code: str, title):
        """Send an ``error`` event frame on *ws* unless it is already closed."""
        if ws.closed:
            return
        await ws.send_json({
            'event': 'error',
            'status': status,
            'message': message,
            'error': {
                'code': code,
                'title': title,
            },
        })

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List ``{"id", "updated_at"}`` of every chunk of a conversation.

        Responds 404 if the conversation does not exist and 403 if it does not
        belong to the acting user. Chunks are ordered by ascending id.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)

        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)
            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Select only the two needed columns. session.scalars() would
            # yield just the FIRST column (the int id) per row, so read full
            # rows via execute() instead.
            stmt = (
                select(ConversationChunkModel.id, ConversationChunkModel.updated_at)
                .where(ConversationChunkModel.conversation_id == conversation_id)
                .order_by(ConversationChunkModel.id.asc())
            )
            rows = await session.execute(stmt)

            conversation_chunk_list = [
                {"id": row.id, "updated_at": row.updated_at}
                for row in rows
            ]

        return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk's loaded column values.

        Responds 404 if the chunk does not exist and 403 if its conversation
        is missing or does not belong to the acting user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)
            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # Load the owning conversation explicitly: implicitly lazy-loading
            # the ``conversation`` relationship on an AsyncSession fails unless
            # it is configured for eager loading. session.get() reuses the
            # identity map if it is already loaded.
            conversation = await session.get(
                ConversationModel, conversation_data.conversation_id)

            if conversation is None or conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Drop SQLAlchemy bookkeeping (e.g. ``_sa_instance_state``), which
            # is not part of the chunk's data and is not JSON-serialisable.
            chunk_dict = {
                key: value
                for key, value in vars(conversation_data).items()
                if not key.startswith("_")
            }

            return await utils.web.api_response(1, chunk_dict, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return ``{"tokens": ...}`` as computed by TikTokenService for ``question``."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Stream a chat completion for a page over a websocket.

        Non-websocket requests get a 400 ``protocol-mismatch`` response.

        Websocket protocol (server -> client):
          * ``"+" + text``                                  streamed answer text
          * ``{"event": "extract_doc", "status": 1, ...}``  extracted documents
          * ``{"event": "done", "status": 1, **result}``    final result
          * ``{"event": "error", "status": -1 | -2, ...}``  failure
            (-2 = page index not found)

        For ``caller == "user"`` a MediaWiki token transaction is started,
        ended with the reported ``total_tokens`` on success, and cancelled on
        failure.

        Returns:
            The ``web.WebSocketResponse`` (closed), or the 400 API response.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        user_id = request.get("user")
        caller = request.get("caller")

        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")

        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if not await chat_complete_service.page_index_exists():
                    await ChatComplete._ws_send_error(
                        ws, -2, "Page index not found.", 'page_not_found', page_title)
                    return ws  # the finally block below closes the socket

                tokens = await tiktoken.get_tokens(question)

                transaction_id = None
                if caller == "user":
                    transaction_id = await mwapi.chat_complete_start_transaction(
                        user_id, "chatcomplete", tokens, extra_limit)

                async def on_message(text: str):
                    # Send message to client, start with "+" to indicate it's a message;
                    # using JSON would make each packet ~10x larger.
                    await ws.send_str("+" + text)

                async def on_extracted_doc(doc: list):
                    await ws.send_json({
                        'event': 'extract_doc',
                        'status': 1,
                        'doc': doc
                    })

                try:
                    chat_res = await chat_complete_service.chat_complete(
                        question, on_message, on_extracted_doc,
                        conversation_id=conversation_id, user_id=user_id,
                        embedding_search={
                            "limit": extra_limit,
                            "in_collection": in_collection,
                        })

                    # Settle the transaction BEFORE notifying the client, so a
                    # client that disconnects at the last moment cannot turn a
                    # completed answer into a cancelled (refunded) transaction.
                    if transaction_id:
                        await mwapi.chat_complete_end_transaction(
                            transaction_id, chat_res["total_tokens"])
                        transaction_id = None

                    await ws.send_json({
                        'event': 'done',
                        'status': 1,
                        **chat_res,
                    })
                except Exception as e:
                    err_msg = f"Error while processing chat complete request: {e}"
                    traceback.print_exc()

                    # Release the reservation first: sending on a dying socket
                    # may itself raise and would otherwise skip the cancel.
                    if transaction_id:
                        await mwapi.chat_complete_cancel_transaction(
                            transaction_id, error=err_msg)

                    await ChatComplete._ws_send_error(
                        ws, -1, err_msg, 'internal_error', page_title)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            await ChatComplete._ws_send_error(
                ws, -1, err_msg, 'internal_error', page_title)
        finally:
            if not ws.closed:
                await ws.close()

        return ws
|