You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
616 lines
22 KiB
Python
616 lines
22 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
import time
|
|
import traceback
|
|
from libs.config import Config
|
|
from server.controller.task.ChatCompleteTask import ChatCompleteTask
|
|
from server.model.base import clone_model
|
|
from server.model.chat_complete.bot_persona import BotPersonaHelper
|
|
from server.model.toolbox_ui.conversation import ConversationHelper
|
|
from utils.local import noawait
|
|
from aiohttp import web
|
|
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
|
|
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage
|
|
from service.database import DatabaseService
|
|
from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
|
|
from service.tiktoken import TikTokenService
|
|
import utils.web
|
|
|
|
class ChatComplete:
    """HTTP / WebSocket handlers for the chat-complete API.

    Every public handler is a ``@staticmethod`` wrapped by ``utils.web.token_auth``.
    JSON handlers answer through ``utils.web.api_response``: ``1`` on success,
    ``-1`` plus an ``error`` dict (``code`` / ``message``) on failure.
    """

    @staticmethod
    def _resolve_user_id(request: web.Request, params: dict):
        """Return the id of the user the request acts for.

        Callers authenticated as ``"user"`` always act as themselves
        (``request["user"]``); any other caller type may name a user through the
        optional ``user_id`` parameter.
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    def _parse_packed_message_id(packed_message_id: str):
        """Split ``"<chunk_id>,<message_id>"`` into ``(int chunk_id, str message_id)``.

        Returns ``None`` when the separator is missing or the chunk id is not an
        integer, so the caller can answer 400 instead of crashing.
        """
        raw_chunk_id, separator, msg_id = packed_message_id.partition(",")
        if not separator:
            return None
        try:
            return int(raw_chunk_id), msg_id
        except ValueError:
            return None

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunk ids of a conversation owned by the acting user.

        Params: ``id`` (conversation id, required); ``user_id`` (only honoured for
        non-user callers). Responds 404 if the conversation does not exist and 403
        if it belongs to another user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("id")

        db = await DatabaseService.create(request.app)

        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            conversation_info = await conversation_helper.find_by_id(conversation_id)

            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            conversation_chunk_result = await conversation_chunk_helper.get_chunk_id_list(conversation_id)
            conversation_chunk_list = list(conversation_chunk_result)

        return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk (its messages and token count).

        Params: ``chunk_id`` (required); ``user_id`` (only honoured for non-user
        callers). Responds 404 if the chunk or its parent conversation is missing,
        and 403 if the conversation belongs to another user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
            if chunk_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)

            # A chunk whose conversation no longer exists has no verifiable owner;
            # refuse it instead of serving it to every caller.
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            chunk_dict = {
                "id": chunk_info.id,
                "conversation_id": chunk_info.conversation_id,
                "message_data": chunk_info.message_data,
                "tokens": chunk_info.tokens,
                "updated_at": chunk_info.updated_at,
            }

        return await utils.web.api_response(1, chunk_dict, request=request)

    @staticmethod
    @utils.web.token_auth
    async def fork_conversation(request: web.Request):
        """Copy a conversation (and one of its chunks) into a new conversation.

        Params:
            id: source conversation id (required).
            message_id: optional ``"<chunk_id>,<message_id>"``; fork from that chunk
                and drop every message after the named one. Without it the newest
                chunk is copied whole.
            new_title: optional title for the copy.
            user_id: only honoured for non-user callers.

        Responds ``{"conversation_id": <new id>}``; 400 for a malformed
        ``message_id``, 403 if the conversation belongs to another user, 404 if the
        conversation, chunk or message cannot be found.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "id": {
                "required": True,
                "type": int,
            },
            "message_id": {
                "required": False,
                "type": str
            },
            "new_title": {
                "required": False,
                "type": str
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)

        conversation_id: int = params.get("id")
        packed_message_id: str = params.get("message_id")
        new_title = params.get("new_title")

        chunk_id = None
        msg_id = None
        if packed_message_id is not None:
            parsed = ChatComplete._parse_packed_message_id(packed_message_id)
            if parsed is None:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid params. message_id must be \"<chunk_id>,<message_id>\"."
                }, http_status=400, request=request)
            chunk_id, msg_id = parsed

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            conversation_info = await conversation_helper.find_by_id(conversation_id)
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Pick the chunk to copy: the one named by message_id, else the newest one.
            if chunk_id is not None:
                chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
                if chunk_info is None or chunk_info.conversation_id != conversation_id:
                    return await utils.web.api_response(-1, error={
                        "code": "conversation-chunk-not-found",
                        "message": "Conversation chunk not found."
                    }, http_status=404, request=request)
            else:
                chunk_info = await conversation_chunk_helper.get_newest_chunk(conversation_id)

            # Decide which messages to keep BEFORE writing anything, so an unknown
            # message id is rejected without leaving an orphan forked conversation.
            kept_messages = None
            if chunk_info is not None:
                kept_messages = list(chunk_info.message_data or [])
                if msg_id is not None:
                    split_message_pos = next(
                        (i for i, msg_data in enumerate(kept_messages) if msg_data.get("id") == msg_id),
                        None,
                    )
                    if split_message_pos is None:
                        return await utils.web.api_response(-1, error={
                            "code": "conversation-message-not-found",
                            "message": "Conversation message not found."
                        }, http_status=404, request=request)
                    # Keep the selected message itself, drop everything after it.
                    kept_messages = kept_messages[:split_message_pos + 1]

            new_conversation: ConversationModel = clone_model(conversation_info)
            if new_title is not None:
                new_conversation.title = new_title
            new_conversation = await conversation_helper.add(new_conversation)

            if kept_messages is not None:
                new_chunk: ConversationChunkModel = clone_model(chunk_info)
                new_chunk.conversation_id = new_conversation.id

                # Assign a fresh list instead of mutating the cloned one in place, so
                # the source chunk's message_data is never touched even if the clone
                # is shallow. A "forked" notice records where the copy came from.
                # NOTE(review): `tokens` is copied unchanged even when messages were cut.
                new_chunk.message_data = [
                    {
                        "id": utils.web.generate_uuid(),
                        "role": "notice",
                        "type": "forked",
                        "data": {
                            "original_conversation_id": conversation_info.id,
                            "original_title": conversation_info.title,
                        }
                    },
                    *kept_messages,
                ]

                # Use the last assistant reply as the fork's description.
                last_assistant_message = next(
                    (msg for msg in reversed(new_chunk.message_data) if msg.get("role") == "assistant"),
                    None,
                )
                if last_assistant_message is not None:
                    new_conversation.description = last_assistant_message["content"][0:150]
                    # Awaited like every other helper call here; the old un-awaited
                    # call presumably never ran, so the description was not saved.
                    await conversation_helper.update(new_conversation)

                await conversation_chunk_helper.add(new_chunk)

        return await utils.web.api_response(1, {
            "conversation_id": new_conversation.id,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return the token count of ``question`` as computed by TikTokenService."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_point_usage(request: web.Request):
        """Estimate the points a question would cost with a given bot.

        Params: ``question`` and ``bot_id`` (required), ``extract_limit`` (optional).
        Responds with ``point_usage``, the raw ``tokens`` of the question and the
        ``predict_tokens`` used for the estimate; 404 if the bot is unknown.
        """
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True,
            },
            "bot_id": {
                "type": str,
                "required": True,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
        })

        question = params.get("question")
        bot_id = params.get("bot_id")

        tiktoken = await TikTokenService.create()
        db = await DatabaseService.create(request.app)

        tokens = await tiktoken.get_tokens(question)

        # NOTE(review): `extract_limit` is accepted but unused; the estimate adds one
        # document's worth of extract tokens. Confirm whether it should scale with it.
        estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int)
        predict_tokens = tokens + estimated_extract_tokens_per_doc

        async with BotPersonaHelper(db) as bot_persona_helper:
            persona_info = await bot_persona_helper.find_by_bot_id(bot_id)

            if persona_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "bot-not-found",
                    "message": "Bot not found."
                }, http_status=404, request=request)

            point_usage = calculate_point_usage(predict_tokens, persona_info.cost_fixed,
                                                persona_info.cost_fixed_tokens, persona_info.cost_per_token)

        return await utils.web.api_response(1, {
            "point_usage": point_usage,
            "tokens": tokens,
            "predict_tokens": predict_tokens,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_persona_list(request: web.Request):
        """Return one page of bot personas plus the total ``page_count``."""
        params = await utils.web.get_param(request, {
            "page": {
                "type": int,
                "required": False,
                "default": 1,
            }
        })

        page = params.get("page")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as bot_persona_helper:
            persona_list = await bot_persona_helper.get_list(page=page)
            page_count = await bot_persona_helper.get_page_count()

            persona_data_list = [
                {
                    "id": persona.id,
                    "bot_id": persona.bot_id,
                    "bot_name": persona.bot_name,
                    "bot_avatar": persona.bot_avatar,
                    "bot_description": persona.bot_description,
                    "model_id": persona.model_id,
                    "model_name": persona.model_name,
                    "cost_fixed": persona.cost_fixed
                }
                for persona in persona_list
            ]

        return await utils.web.api_response(1, {
            "list": persona_data_list,
            "page_count": page_count,
        }, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_persona_info(request: web.Request):
        """Return every public attribute of one bot persona.

        Look-up is by ``id`` if given, otherwise by ``bot_id``; 400 if neither is
        supplied, 404 if no persona matches.
        """
        params = await utils.web.get_param(request, {
            "id": {
                "type": int,
            },
            "bot_id": {
                "type": str,
            }
        })

        persona_id = params.get("id")
        bot_id = params.get("bot_id")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as bot_persona_helper:
            if persona_id is not None:
                persona_info = await bot_persona_helper.find_by_id(persona_id)
            elif bot_id is not None:
                persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
            else:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid params. Please specify id or bot_id."
                }, http_status=400, request=request)

            # Previously a missing persona crashed on `None.__dict__`.
            if persona_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "bot-not-found",
                    "message": "Bot not found."
                }, http_status=404, request=request)

            # Skip underscore-prefixed attributes (private / ORM bookkeeping).
            persona_info_res = {
                key: value
                for key, value in persona_info.__dict__.items()
                if not key.startswith("_")
            }

        return await utils.web.api_response(1, persona_info_res, request=request)

    @staticmethod
    @utils.web.token_auth
    async def start_chat_complete(request: web.Request):
        """Create a chat-complete task, schedule it, and return its ids.

        The client follows the task's output through ``chat_complete_stream`` using
        the returned ``task_id``. Responds 404 if the wiki page is missing, 403 if
        the user lacks points, 400 if the question is too long, 500 otherwise.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "bot_id": {
                "type": str,
                "required": True,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
            "edit_message_id": {
                "type": str,
                "required": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")

        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        bot_id = params.get("bot_id")

        extract_limit = params.get("extract_limit")
        in_collection = params.get("in_collection")

        edit_message_id = params.get("edit_message_id")

        dbs = await DatabaseService.create(request.app)

        try:
            chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
            init_res = await chat_complete_task.init(
                question,
                conversation_id=conversation_id,
                edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search={
                    "limit": extract_limit,
                    "in_collection": in_collection,
                },
            )

            # Scheduled through noawait rather than awaited here, so the response
            # returns before the completion finishes.
            noawait.add_task(chat_complete_task.run())

            return await utils.web.api_response(1, data={
                "conversation_id": init_res["conversation_id"],
                "chunk_id": init_res["chunk_id"],
                "question_tokens": init_res["question_tokens"],
                "extract_doc": init_res["extract_doc"],
                "task_id": chat_complete_task.task_id,
            }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-1, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404, request=request)
        except MediaWikiUserNoEnoughPointsException:
            # This used to be `"..." % user_id`; with no placeholder in the string
            # that raised TypeError and turned the 403 into an unhandled error.
            error_msg = "Does not have enough points."
            return await utils.web.api_response(-1, error={
                "code": "no-enough-points",
                "message": error_msg
            }, http_status=403, request=request)
        except ChatCompleteQuestionTooLongException as e:
            error_msg = "Question too long."
            return await utils.web.api_response(-1, error={
                "code": "question-too-long",
                "limit": e.tokens_limit,
                "current": e.tokens_current,
                "message": error_msg
            }, http_status=400, request=request)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": err_msg
            }, http_status=500, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete_stream(request: web.Request):
        """Stream a chat-complete task to the client over a WebSocket.

        Param: ``task_id`` (required, as returned by ``start_chat_complete``).
        Frames sent:

        * ``{"event": "connected", "outputed_message": ...}`` with output so far;
        * ``"+<delta>"`` text frames for each new piece of output;
        * ``{"event": "finished", "result": ...}`` or ``{"event": "error", ...}``.

        A ping is sent every 15 seconds while waiting. Non-WebSocket requests get a
        JSON 400. The WebSocketResponse is always returned to aiohttp.
        """
        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, error={
                "code": "websocket-required",
                "message": "This API only accept websocket connection."
            }, http_status=400, request=request)

        params = await utils.web.get_param(request, {
            "task_id": {
                "type": str,
                "required": True,
            }
        })

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        task_id = params.get("task_id")

        task = ChatCompleteTask.get_by_id(task_id)
        if task is None:
            await ws.send_json({
                'event': 'error',
                'status': -1,
                'message': "Task not found.",
                'error': {
                    'code': "task-not-found",
                    'info': "Task not found.",
                },
            })
            await ws.close()
            return ws

        if request.get("caller") == "user":
            user_id = request.get("user")
            if task.user_id != user_id:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': "Permission denied.",
                    'error': {
                        'code': "permission-denied",
                        'info': "Permission denied.",
                    },
                })
                await ws.close()
                return ws

        if task.is_finished:
            # Task already done: replay its outcome in one go.
            if task.error is not None:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': str(task.error),
                    'error': {
                        'code': "internal-server-error",
                        'info': str(task.error),
                    },
                })
            elif task.result is not None:
                await ws.send_json({
                    'event': 'connected',
                    'status': 1,
                    'outputed_message': "".join(task.chunks),
                })
                await ws.send_json({
                    'event': 'finished',
                    'status': 1,
                    'result': task.result
                })
            await ws.close()
            return ws

        detached = False

        async def on_closed():
            # Reachable from several paths (send failure, socket closed, task
            # finished); guard so list.remove() never runs twice and raises.
            nonlocal detached
            if detached:
                return
            detached = True
            task.on_message.remove(on_message)
            task.on_finished.remove(on_finished)
            task.on_error.remove(on_error)

        async def on_message(delta_message: str):
            try:
                await ws.send_str("+" + delta_message)
            except ConnectionResetError:
                await on_closed()

        async def on_finished(result: ChatCompleteServiceResponse):
            try:
                # The full message text was already streamed as deltas.
                ignored_keys = ["message"]
                response_result = {
                    "point_usage": task.point_usage,
                }
                for k, v in result.items():
                    if k not in ignored_keys:
                        response_result[k] = v
                await ws.send_json({
                    'event': 'finished',
                    'status': 1,
                    'result': response_result
                })

                await ws.close()
            except ConnectionResetError:
                await on_closed()

        async def on_error(err: Exception):
            try:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': str(err),
                    'error': {
                        'code': "internal-server-error",
                        'info': str(err),
                    },
                })

                await ws.close()
            except ConnectionResetError:
                await on_closed()

        task.on_message.append(on_message)
        task.on_finished.append(on_finished)
        task.on_error.append(on_error)

        # Send the output produced so far; unregister the callbacks if the client
        # has already gone away.
        try:
            await ws.send_json({
                'event': 'connected',
                'status': 1,
                'outputed_message': "".join(task.chunks),
            })
        except ConnectionResetError:
            await on_closed()
            return ws

        last_heartbeat = time.time()
        while True:
            current_time = time.time()
            if ws.closed or task.is_finished:
                await on_closed()
                break

            if current_time - last_heartbeat >= 15:
                try:
                    await ws.ping('{"event":"ping"}'.encode('utf-8'))
                    last_heartbeat = current_time
                except ConnectionResetError:
                    await on_closed()
                    break
            await asyncio.sleep(0.1)

        # close() is a no-op if a callback already closed the socket; aiohttp
        # expects the handler to return the response object.
        await ws.close()
        return ws