You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

297 lines
10 KiB
Python

from __future__ import annotations
import asyncio
import json
import sys
import time
import traceback
from local import noawait
from typing import Optional, Callable, TypedDict
from aiohttp import web
from sqlalchemy import select
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web
# Module-level registry of chat-complete tasks, keyed by ChatCompleteTask.task_id.
# Entries are added in ChatComplete.start_chat_complete after a task has been
# initialised, and removed again in ChatCompleteTask._exit.
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
    """One chat-complete request for a wiki page, run as a background task.

    Usage in this file: construct the task, ``await init(...)`` to prepare the
    request, then call ``start()`` to run it. Coroutine callbacks appended to
    ``on_message`` / ``on_error`` are awaited for each streamed message delta
    and for a failure during ``run()`` respectively.
    """

    def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
        """Create an un-initialised task.

        Args:
            dbs: Database service handed to ``ChatCompleteService``.
            user_id: User the request is made (and billed) for.
            page_title: Title of the wiki page the question is about.
            is_system: When true, no usage transaction is opened in ``init()``.
        """
        self.task_id = utils.web.generate_uuid()
        self.on_message: list[Callable] = []
        self.on_error: list[Callable] = []
        self.chunks: list[str] = []

        # Both are assigned in init(); declared here for type checkers only.
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService

        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        self.is_system = is_system

        # NOTE: the attribute name is misspelled ("transatcion"); it is kept
        # as-is because code outside this class may read it.
        self.transatcion_id: Optional[str] = None
        self.point_cost: int = 0

    async def init(self, question: str, conversation_id: Optional[str] = None,
                   embedding_search: Optional[EmbeddingSearchArgs] = None):
        """Prepare the request: enter the service, open a usage transaction, build the prompt.

        Args:
            question: The user's question.
            conversation_id: Existing conversation to continue, if any.
            embedding_search: Embedding search options; its ``"limit"`` entry
                (default 10) is the extract limit reported to the usage transaction.

        Returns:
            The result of ``prepare_chat_complete``. The HTTP handler in this
            file reads ``"question_tokens"`` and ``"extract_doc"`` from it.

        Raises:
            MediaWikiPageNotFoundException: If the page has no index.
        """
        self.tiktoken = await TikTokenService.create()
        self.mwapi = MediaWikiApi.create()

        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        self.chat_complete = await self.chat_complete_service.__aenter__()

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)

            question_tokens = await self.tiktoken.get_tokens(question)
            # embedding_search defaults to None, so it must not be subscripted directly.
            extract_limit = (embedding_search or {}).get("limit") or 10

            self.transatcion_id = None
            self.point_cost = 0
            if not self.is_system:
                usage_res = await self.mwapi.chat_complete_start_transaction(
                    self.user_id, "chatcomplete", question_tokens, extract_limit)
                self.transatcion_id = usage_res.get("transaction_id")
                self.point_cost = usage_res.get("point_cost")

            return await self.chat_complete.prepare_chat_complete(
                question, conversation_id=conversation_id,
                user_id=self.user_id, embedding_search=embedding_search)
        except BaseException:
            # The service was entered above; leave it on every failure path
            # (including cancellation) so it is not leaked, then re-raise.
            await self._exit()
            raise

    async def _on_message(self, delta_message: str):
        """Await every registered ``on_message`` callback with the message delta."""
        for callback in self.on_message:
            await callback(delta_message)

    async def _on_error(self, err: Exception):
        """Await every registered ``on_error`` callback with the exception."""
        for callback in self.on_error:
            await callback(err)

    async def run(self):
        """Finish the chat completion and settle the usage transaction.

        On success the point cost is stored and the transaction (if any) is
        ended with the total token count. On failure the error is printed, the
        transaction (if any) is cancelled and ``on_error`` callbacks are
        notified; the exception is not re-raised. The service is always exited.
        """
        try:
            # Must be awaited: the result is subscripted below.
            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

            await self.chat_complete.set_latest_point_cost(self.point_cost)

            if self.transatcion_id:
                await self.mwapi.chat_complete_end_transaction(
                    self.transatcion_id, chat_res["total_tokens"])
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"

            print(err_msg, file=sys.stderr)
            traceback.print_exc()

            if self.transatcion_id:
                await self.mwapi.chat_complete_cancel_transaction(
                    self.transatcion_id, error=err_msg)

            await self._on_error(e)
        finally:
            await self._exit()

    async def _exit(self):
        """Exit the chat-complete service and drop this task from the registry."""
        try:
            await self.chat_complete_service.__aexit__(None, None, None)
        finally:
            # pop() instead of del: init() can call _exit() before the task
            # has been registered in chat_complete_tasks.
            chat_complete_tasks.pop(self.task_id, None)

    @noawait.wrap
    async def start(self):
        """Run the task; wrapped by ``noawait.wrap`` and called without ``await`` in this file."""
        await self.run()
class ChatComplete:
    """HTTP handlers for the chat-complete API.

    Every public handler is a static coroutine wrapped by
    ``utils.web.token_auth`` and answers through ``utils.web.api_response``.
    """

    @staticmethod
    def _resolve_user_id(request: web.Request, params):
        """Return the acting user id: the request's own user for "user" callers, else the ``user_id`` parameter."""
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the id and ``updated_at`` of every chunk in a conversation, ordered by id.

        Responds with 404 if the conversation does not exist and 403 if it
        belongs to a different user than the acting one.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)
        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Two columns are selected, so the rows must be read with
            # execute(): scalars() would yield only the first column (the id)
            # and the attribute access below would fail.
            stmt = select(ConversationChunkModel.id, ConversationChunkModel.updated_at) \
                .where(ConversationChunkModel.conversation_id == conversation_id) \
                .order_by(ConversationChunkModel.id.asc())

            rows = await session.execute(stmt)

            conversation_chunk_list = [
                {
                    "id": row.id,
                    "updated_at": row.updated_at
                }
                for row in rows
            ]

            return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return a single conversation chunk by id.

        Responds with 404 if the chunk does not exist and 403 if its
        conversation belongs to a different user than the acting one.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # NOTE(review): this reads the `conversation` relationship on the
            # loaded chunk; confirm it is eagerly loaded for async sessions.
            if conversation_data.conversation.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # NOTE(review): __dict__ of a mapped instance also carries ORM
            # bookkeeping entries; confirm api_response can serialise it.
            return await utils.web.api_response(1, conversation_data.__dict__, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of the ``question`` parameter as ``{"tokens": ...}``."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def start_chat_complete(request: web.Request):
        """Create, initialise and start a ChatCompleteTask for the requested page.

        On success responds with the question token count, the extracted
        documents and the new task id. Responds with 404 when the page is not
        found and with 500 for any other failure.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")
        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")

        extract_limit = params.get("extract_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)

        try:
            # Any caller other than "user" is treated as a system request.
            chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
            init_res = await chat_complete_task.init(question, conversation_id=conversation_id, embedding_search={
                "limit": extract_limit,
                "in_collection": in_collection,
            })

            chat_complete_tasks[chat_complete_task.task_id] = chat_complete_task

            chat_complete_task.start()

            # api_response is awaited at every other call site in this file;
            # it must be awaited here too so a response, not a coroutine, is returned.
            return await utils.web.api_response(1, data={
                "question_tokens": init_res["question_tokens"],
                "extract_doc": init_res["extract_doc"],
                "task_id": chat_complete_task.task_id,
            }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-1, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404, request=request)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "chat-complete-error",
                "message": err_msg
            }, http_status=500, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete_stream(request: web.Request):
        """Placeholder handler; the body is empty and returns ``None``."""
        pass