You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

327 lines
12 KiB
Python

import asyncio
import json
import time
import traceback
from aiohttp import WSMsgType, web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web
class ChatCompleteWebSocketController:
    """Drive a single chat-complete WebSocket connection.

    The controller upgrades the HTTP request to a WebSocket, dispatches
    incoming JSON text frames by their ``event`` field and closes the socket
    when no text frame has been received for ``IDLE_TIMEOUT`` seconds.
    """

    # Seconds without an incoming text frame before the socket is closed.
    IDLE_TIMEOUT = 30

    def __init__(self, request: web.Request):
        self.request = request
        self.ws = None
        self.db = None
        self.query = None
        self.user_id = None
        self.title = None
        self.chat_complete = None
        self.closed = False
        self.refreshed_time = 0
        # Strong references to background tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._tasks = set()

    def _spawn(self, coro):
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def run(self):
        """Serve the connection until it closes and return the WebSocket response.

        Text frames are parsed as JSON; ``chatcomplete`` events are handled in
        a background task and ``ping`` events are answered with ``pong``.
        """
        self.ws = web.WebSocketResponse()
        await self.ws.prepare(self.request)
        self.refreshed_time = time.time()

        self.db = await DatabaseService.create(self.request.app)
        self.query = self.request.query

        # A "user" caller acts for itself; any other caller names the user
        # through the query string.
        if self.request.get("caller") == "user":
            self.user_id = self.request.get("user")
        else:
            self.user_id = self.query.get("user_id")
        self.title = self.query.get("title")

        # Heartbeat watchdog: closes the socket after IDLE_TIMEOUT of silence.
        watchdog = self._spawn(self._timeout_task())
        try:
            async for msg in self.ws:
                if msg.type == WSMsgType.TEXT:
                    await self._handle_text(msg.data)
                elif msg.type == WSMsgType.ERROR:
                    print('ws connection closed with exception %s' %
                          self.ws.exception())
        finally:
            if not self.closed:
                # The receive loop ended on its own, so the watchdog is only
                # sleeping: stop it now instead of letting it poll until the
                # idle timeout. When the watchdog itself initiated the close
                # (self.closed is already True) it is left alone so that its
                # pending ws.close() can complete.
                self.closed = True
                watchdog.cancel()
        return self.ws

    async def _handle_text(self, raw):
        """Decode one text frame and dispatch it by its ``event`` field.

        Any failure is printed and reported to the client as an ``error``
        event instead of terminating the receive loop.
        """
        try:
            data = json.loads(raw)
            event = data.get('event')
            self.refreshed_time = time.time()
            if event == 'chatcomplete':
                self._spawn(self._chatcomplete(data))
            elif event == 'ping':
                await self.ws.send_json({
                    'event': 'pong'
                })
        except Exception as e:
            print(e)
            traceback.print_exc()
            if not self.ws.closed:
                await self.ws.send_json({
                    'event': 'error',
                    'error': str(e)
                })

    async def _timeout_task(self):
        """Poll once per second and close the socket once it has been idle too long."""
        while not self.closed:
            if time.time() - self.refreshed_time > self.IDLE_TIMEOUT:
                self.closed = True
                await self.ws.close()
                return
            await asyncio.sleep(1)

    async def _chatcomplete(self, params: dict):
        """Handle a ``chatcomplete`` event.

        NOTE(review): the visible source only extracts the two fields below;
        the rest of the handler is not implemented here.
        """
        question = params.get("question")
        conversation_id = params.get("conversation_id")
class ChatComplete:
    """Request handlers for conversation chunks, token counting and chat completion."""

    @staticmethod
    def _resolve_user_id(request: web.Request, params: dict):
        """Return the user id the request acts for.

        A ``"user"`` caller always acts for itself (``request["user"]``); any
        other caller supplies the optional ``user_id`` parameter.
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the ``id`` and ``updated_at`` of every chunk of a conversation.

        Responds with 404 when the conversation does not exist and with 403
        when it belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)
        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)
            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            stmt = (
                select(ConversationChunkModel.id,
                       ConversationChunkModel.updated_at)
                .where(ConversationChunkModel.conversation_id == conversation_id)
                .order_by(ConversationChunkModel.id.asc())
            )
            # execute() rather than scalars(): scalars() yields only the first
            # selected column, which would drop updated_at and hand back bare
            # ids instead of rows.
            rows = await session.execute(stmt)
            conversation_chunk_list = [
                {"id": chunk_id, "updated_at": updated_at}
                for chunk_id, updated_at in rows
            ]

        return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk by id.

        Responds with 404 when the chunk does not exist and with 403 when its
        conversation belongs to a different user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)
            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # NOTE(review): this reads the `conversation` relationship on an
            # async session; presumably it is eagerly loaded - confirm on the
            # model definition.
            if conversation_data.conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # The instance __dict__ also carries SQLAlchemy's private
            # `_sa_instance_state` bookkeeping object, which is not response
            # data, so it is filtered out.
            # NOTE(review): a loaded `conversation` object, if present in
            # __dict__, is still passed through as before.
            chunk_payload = {
                key: value
                for key, value in vars(conversation_data).items()
                if not key.startswith("_sa_")
            }
            return await utils.web.api_response(1, chunk_payload, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of the required ``question`` parameter."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Answer a question about a page over a WebSocket connection.

        Frames sent to the client:

        * text frames prefixed with ``"+"`` - pieces of the answer;
        * ``{"event": "extract_doc", ...}`` - the extracted documents;
        * ``{"event": "done", ...}`` - the final result of the completion;
        * ``{"event": "error", ...}`` - status ``-2`` when the page index is
          missing, status ``-1`` for any other failure.

        For a ``"user"`` caller a transaction is started before the
        completion, ended with the consumed token count on success and
        cancelled on failure. A plain HTTP request is rejected with status
        400.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async def send_error(status: int, code: str, message: str):
            # Report a failure to the client unless the socket is already closed.
            if ws.closed:
                return
            await ws.send_json({
                'event': 'error',
                'status': status,
                'message': message,
                'error': {
                    'code': code,
                    'title': page_title,
                },
            })

        async def on_message(text: str):
            # Send message to client, start with "+" to indicate it's a message
            # use json will make the package 10x larger
            await ws.send_str("+" + text)

        async def on_extracted_doc(doc: list):
            await ws.send_json({
                'event': 'extract_doc',
                'status': 1,
                'doc': doc
            })

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if await chat_complete_service.page_index_exists():
                    tokens = await tiktoken.get_tokens(question)

                    transaction_id = None
                    if caller == "user":
                        transaction_id = await mwapi.chat_complete_start_transaction(
                            user_id, "chatcomplete", tokens, extra_limit)

                    try:
                        chat_res = await chat_complete_service.chat_complete(
                            question, on_message, on_extracted_doc,
                            conversation_id=conversation_id, user_id=user_id,
                            embedding_search={
                                "limit": extra_limit,
                                "in_collection": in_collection,
                            })

                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            **chat_res,
                        })

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(
                                transaction_id, chat_res["total_tokens"])
                    except Exception as e:
                        err_msg = f"Error while processing chat complete request: {e}"
                        traceback.print_exc()
                        try:
                            await send_error(-1, 'internal_error', err_msg)
                        finally:
                            # Cancel even when the error frame cannot be
                            # delivered (e.g. the client already went away),
                            # so the started transaction is not left open.
                            if transaction_id:
                                await mwapi.chat_complete_cancel_transaction(
                                    transaction_id, error=err_msg)
                else:
                    await send_error(-2, 'page_not_found', "Page index not found.")
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()
            await send_error(-1, 'internal_error', err_msg)
        finally:
            if not ws.closed:
                await ws.close()

        # aiohttp handlers are expected to return the response they prepared.
        return ws