|
|
|
from __future__ import annotations

import math
import sys
import time
import traceback
from typing import Callable, Optional, Tuple, TypedDict

from aiohttp import web
from sqlalchemy.orm.attributes import flag_modified

from server.model.chat_complete.bot_persona import BotPersonaHelper
from server.model.chat_complete.conversation import (
    ConversationChunkHelper,
    ConversationChunkModel,
)
from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel
from libs.config import Config
import utils.config
import utils.web
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
from service.tiktoken import TikTokenService
|
|
|
|
|
|
|
|
class ChatCompleteQuestionTooLongException(Exception):
    """Raised when a question uses more tokens than the configured input limit.

    Attributes:
        tokens_limit: The maximum number of tokens that was allowed.
        tokens_current: The number of tokens the rejected question used.
    """

    def __init__(self, tokens_limit: int, tokens_current: int):
        """Build the error message and keep both token counts for callers.

        Args:
            tokens_limit: The maximum number of tokens allowed.
            tokens_current: The number of tokens of the offending question.
        """
        super().__init__(f"Question too long: {tokens_current} > {tokens_limit}")
        self.tokens_limit = tokens_limit
        self.tokens_current = tokens_current
|
|
|
|
|
|
|
|
class ChatCompleteServicePrepareResponse(TypedDict):
    """Shape of the dict returned by ``ChatCompleteService.prepare_chat_complete``."""

    # Documents returned by the embedding search (the producer passes None
    # when no search was requested).
    extract_doc: list
    # Token count of the question; the producer adds the embedding-search
    # token usage to it when documents were extracted.
    question_tokens: int
    # Id of the conversation record the question belongs to.
    conversation_id: int
    # Id of the conversation chunk that will receive the new messages.
    chunk_id: int
    # Identifier of the API configuration -- presumably the one used for the
    # completion; confirm against the producer.
    api_id: str
|
|
|
|
|
|
|
|
class ChatCompleteServiceResponse(TypedDict):
    """Shape of the dict returned by ``ChatCompleteService.finish_chat_complete``."""

    # Text of the assistant's reply, as reported by the completion API.
    message: str
    # Token count of the reply, as reported by the completion API.
    message_tokens: int
    # Total token count reported by the completion API for the request.
    total_tokens: int
    # Finish reason reported by the completion API.
    finish_reason: str
    # Generated id of the stored user (question) message.
    question_message_id: str
    # Generated id of the stored assistant (response) message.
    response_message_id: str
    # Side-channel updates for the client; holds "title" when a conversation
    # title was generated during this call.
    delta_data: dict
|
|
|
|
|
|
|
|
class ChatCompleteService:
    """Run one question/answer round of a page-scoped chat conversation.

    The service is an async context manager: entering it creates the token
    counter and enters the embedding-search, conversation, conversation-chunk
    and bot-persona helpers; exiting it exits those helpers again.

    A round is two calls on the same instance:

    1. ``prepare_chat_complete`` validates the request, resolves the bot
       persona and API client, and loads or creates the conversation and its
       active chunk.
    2. ``finish_chat_complete`` calls the completion API and stores the new
       question/answer pair.
    """

    def __init__(self, dbs: DatabaseService, title: str):
        """Create the helpers and reset all per-request state.

        Args:
            dbs: Database service handed to every helper.
            title: Page title; the part before the first "/" is kept as
                ``base_title``.
        """
        self.dbs = dbs

        self.title = title
        self.base_title = title.split("/")[0]

        self.embedding_search = EmbeddingSearchService(dbs, title)
        self.conversation_helper = ConversationHelper(dbs)
        self.conversation_chunk_helper = ConversationChunkHelper(dbs)
        self.bot_persona_helper = BotPersonaHelper(dbs)

        # Conversation record and the chunk that receives new messages; both
        # are filled in by prepare_chat_complete().
        self.conversation_info: Optional[ConversationModel] = None
        self.conversation_chunk: Optional[ConversationChunkModel] = None

        # Created in prepare_chat_complete() / __aenter__() respectively.
        self.openai_api: Optional[OpenAIApi] = None
        self.tiktoken: Optional[TikTokenService] = None

        # Documents found by the embedding search, or None.
        self.extract_doc: Optional[list] = None

        self.mwapi = MediaWikiApi.create()

        self.user_id = 0
        self.question = ""
        self.question_tokens: Optional[int] = None
        self.bot_id: str = ""
        self.model: Optional[str] = None
        self.conversation_id: Optional[int] = None
        self.conversation_start_time: Optional[int] = None

        # None makes finish_chat_complete() fall back to the default prompt.
        self.chat_system_prompt: Optional[str] = ""

        self.delta_data = {}

    async def __aenter__(self):
        """Create the token counter, enter every helper and return ``self``."""
        self.tiktoken = await TikTokenService.create()

        await self.embedding_search.__aenter__()
        await self.conversation_helper.__aenter__()
        await self.conversation_chunk_helper.__aenter__()
        await self.bot_persona_helper.__aenter__()

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Exit every helper, forwarding the exception details unchanged."""
        await self.embedding_search.__aexit__(exc_type, exc, tb)
        await self.conversation_helper.__aexit__(exc_type, exc, tb)
        await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)
        await self.bot_persona_helper.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self):
        """Return the embedding search's ``page_index_exists(False)`` result."""
        return await self.embedding_search.page_index_exists(False)

    async def get_question_tokens(self, question: str):
        """Return the token count of ``question`` from the token counter."""
        return await self.tiktoken.get_tokens(question)

    @staticmethod
    def _chunk_ids_after(chunk_id_list, pivot_chunk_id) -> list:
        """Return the ids that follow the first ``pivot_chunk_id`` in the list.

        Returns an empty list when the pivot is absent or is the last entry.
        """
        ordered_ids = list(chunk_id_list)
        try:
            pivot_pos = ordered_ids.index(pivot_chunk_id)
        except ValueError:
            return []
        return ordered_ids[pivot_pos + 1:]

    @staticmethod
    def _find_edit_position(message_data: list, edit_msg_id: str) -> tuple:
        """Locate the message being edited inside one chunk.

        Returns:
            ``(position, tokens_before)``. ``position`` is the index of the
            first message whose "id" equals ``edit_msg_id`` (``None`` when
            there is no such message); ``tokens_before`` is the sum of the
            non-None "tokens" values of the messages scanned before it.
        """
        tokens_before = 0
        for position, msg_data in enumerate(message_data):
            if msg_data.get("id") == edit_msg_id:
                return position, tokens_before
            msg_tokens = msg_data.get("tokens")
            if msg_tokens is not None:
                tokens_before += msg_tokens
        return None, tokens_before

    @staticmethod
    def _format_chat_log(message_log_list: list, bot_name: str) -> str:
        """Render stored messages as a plain-text transcript.

        Messages without a "content" key are skipped. Summary messages are
        emitted verbatim, assistant messages are prefixed with ``bot_name``
        and every other role is prefixed with "User".
        """
        chat_log: list[str] = []
        for message_data in message_log_list:
            if "content" not in message_data:
                continue
            content = message_data["content"]
            msg_role = message_data.get("role")
            if msg_role == "summary":
                chat_log.append(content)
            elif msg_role == "assistant":
                chat_log.append(f"{bot_name}: {content}")
            else:
                chat_log.append(f"User: {content}")
        return "\n".join(chat_log)

    async def prepare_chat_complete(
        self,
        question: str,
        bot_id: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[int] = None,
        question_tokens: Optional[int] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Validate a question and set up everything needed to answer it.

        Args:
            question: The user's question.
            bot_id: Persona to use for a new conversation. An existing
                conversation uses the bot id stored in its ``extra`` data.
            conversation_id: Id of the conversation to continue; a new
                conversation is created when it is None or not found.
            user_id: Id of the requesting user.
            question_tokens: Pre-computed token count of the question; it is
                counted here when None.
            edit_message_id: ``"<chunk_id>,<message_id>"`` of a message to
                replace. Chunks after that chunk are removed and the chunk is
                cut back to just before that message.
            embedding_search: Keyword arguments for the embedding search, or
                None to skip document extraction.

        Returns:
            The extracted documents, the question token count, the
            conversation id, the active chunk id and the API id in use.

        Raises:
            web.HTTPUnauthorized: The conversation belongs to another user.
            ChatCompleteQuestionTooLongException: The question exceeds
                ``chatcomplete.max_input_tokens``.
        """
        if user_id is not None:
            user_id = int(user_id)

        self.user_id = user_id
        self.question = question
        self.conversation_start_time = int(time.time())
        self.bot_id = bot_id

        self.conversation_info = None
        self.conversation_id = None
        if conversation_id is not None:
            self.conversation_id = int(conversation_id)
            self.conversation_info = await self.conversation_helper.find_by_id(
                self.conversation_id
            )

        if (
            self.conversation_info is not None
            and self.conversation_info.user_id != user_id
        ):
            raise web.HTTPUnauthorized()

        if question_tokens is None:
            question_tokens = await self.get_question_tokens(question)
        self.question_tokens = question_tokens
        if self.question_tokens == 0:
            # No usable count: fall back to a rough per-character estimate.
            self.question_tokens = len(question) * 3

        max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int)
        if self.question_tokens > max_input_tokens:
            # Over-long questions are rejected, not truncated.
            raise ChatCompleteQuestionTooLongException(
                max_input_tokens, self.question_tokens
            )

        if self.conversation_info is not None:
            # A running conversation keeps the persona it was created with.
            self.bot_id = self.conversation_info.extra.get("bot_id") or "default"

        bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
        if bot_persona is None:
            self.bot_id = "default"
            self.model = None
            bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
        else:
            self.model = bot_persona.model_id

        # The default persona may be missing too; None lets
        # finish_chat_complete() use the configured default system prompt.
        persona_api_id = None
        self.chat_system_prompt = None
        if bot_persona is not None:
            persona_api_id = bot_persona.api_id
            self.chat_system_prompt = bot_persona.system_prompt

        default_api = Config.get("chatcomplete.default_api", None, str)
        api_id = persona_api_id or default_api
        try:
            self.openai_api = OpenAIApi.create(api_id)
        except OpenAIApiTypeInvalidException:
            print(f"Invalid API type: {persona_api_id}", file=sys.stderr)
            api_id = default_api
            self.openai_api = OpenAIApi.create(api_id)

        self.conversation_chunk = None
        if self.conversation_info is not None:
            await self._resume_conversation(edit_message_id)
        else:
            await self._start_conversation(bot_persona)

        # Extract document from wiki page index
        self.extract_doc = None
        if embedding_search is not None:
            self.extract_doc, token_usage = await self.embedding_search.search(
                question, **embedding_search
            )
            if self.extract_doc is not None:
                self.question_tokens += token_usage

        return ChatCompleteServicePrepareResponse(
            extract_doc=self.extract_doc,
            question_tokens=self.question_tokens,
            conversation_id=self.conversation_info.id,
            chunk_id=self.conversation_chunk.id,
            api_id=api_id,
        )

    async def _resume_conversation(self, edit_message_id: Optional[str]) -> None:
        """Load the active chunk of the current conversation.

        Applies a pending message edit first, then replaces the chunk with a
        freshly stored summary chunk when it holds more tokens than
        ``chatcomplete.max_memory_tokens``.
        """
        chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(
            self.conversation_id
        )

        is_edit = bool(edit_message_id) and "," in edit_message_id
        edit_msg_id = None
        if is_edit:
            # maxsplit=1 so an unexpected extra comma cannot break unpacking.
            raw_chunk_id, edit_msg_id = edit_message_id.split(",", 1)

            # Remove overridden conversation chunks
            stale_chunk_ids = self._chunk_ids_after(chunk_id_list, int(raw_chunk_id))
            if stale_chunk_ids:
                await self.conversation_chunk_helper.remove(stale_chunk_ids)

        self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
            self.conversation_id
        )

        if is_edit:
            # Remove outdated message
            edit_message_pos, old_tokens = self._find_edit_position(
                self.conversation_chunk.message_data, edit_msg_id
            )
            # Compare with None: position 0 is a valid hit and must also
            # truncate the chunk.
            if edit_message_pos is not None:
                self.conversation_chunk.message_data = (
                    self.conversation_chunk.message_data[0:edit_message_pos]
                )
                flag_modified(self.conversation_chunk, "message_data")
                self.conversation_chunk.tokens = old_tokens

        # If the conversation is too long, we need to make a summary
        max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int)
        if self.conversation_chunk.tokens > max_memory_tokens:
            summary, tokens = await self.make_summary(
                self.conversation_chunk.message_data
            )
            summary_chunk = ConversationChunkModel(
                conversation_id=self.conversation_id,
                message_data=[
                    {
                        "role": "summary",
                        "content": summary,
                        "tokens": tokens,
                        "time": int(time.time()),
                    }
                ],
                tokens=tokens,
            )
            self.conversation_chunk = await self.conversation_chunk_helper.add(
                summary_chunk
            )

    async def _start_conversation(self, bot_persona) -> None:
        """Create and store a new conversation and its first chunk.

        The first chunk is seeded with the persona's message log, if any.
        """
        # Default chat log
        init_message_data = []
        if bot_persona is not None and bot_persona.message_log is not None:
            current_time = int(time.time())
            for message in bot_persona.message_log:
                # Copy, so the persona's own message dicts are not modified.
                init_message_data.append(
                    {
                        **message,
                        "id": utils.web.generate_uuid(),
                        "time": current_time,
                    }
                )

        title_info = self.embedding_search.title_index
        new_conversation = ConversationModel(
            user_id=self.user_id,
            module="chatcomplete",
            page_id=title_info.page_id,
            rev_id=title_info.latest_rev_id,
            extra={"bot_id": self.bot_id},
        )
        self.conversation_info = await self.conversation_helper.add(new_conversation)

        first_chunk = ConversationChunkModel(
            conversation_id=self.conversation_info.id,
            message_data=init_message_data,
            tokens=0,
        )
        self.conversation_chunk = await self.conversation_chunk_helper.add(first_chunk)

    async def finish_chat_complete(
        self, on_message: Optional[Callable] = None
    ) -> ChatCompleteServiceResponse:
        """Ask the completion API for an answer and store the exchange.

        Must run after ``prepare_chat_complete`` on the same instance.

        Args:
            on_message: When given, the streaming API is used and this
                callback is passed to it; otherwise the plain API is used.

        Returns:
            The reply, its token figures and finish reason, the generated ids
            of the stored question/response messages, and ``delta_data``.

        Raises:
            Exception: No system prompt could be determined.
        """
        delta_data = {}

        # NOTE(review): only "user"/"assistant" messages are forwarded, so a
        # "summary" message stored in the chunk never reaches the model here
        # -- confirm this is intended.
        message_log = []
        if self.conversation_chunk is not None:
            message_log = [
                {"role": message["role"], "content": message["content"]}
                for message in self.conversation_chunk.message_data
                if message.get("role") in ["user", "assistant"]
            ]

        if self.extract_doc is not None:
            numbered_docs = []
            for number, doc in enumerate(self.extract_doc, start=1):
                doc_text = doc["markdown"] or doc["text"]
                numbered_docs.append("%d. %s" % (number, doc_text))
            doc_prompt_content = "\n".join(numbered_docs)

            doc_prompt = utils.config.get_prompt(
                "extracted_doc", "prompt", {"content": doc_prompt_content}
            )
            message_log.append({"role": "user", "content": doc_prompt})

        system_prompt = self.chat_system_prompt
        if system_prompt is None:
            system_prompt = utils.config.get_prompt("default", "system")

        if system_prompt is None:
            raise Exception("System prompt not found.")

        system_prompt = utils.config.format_prompt(system_prompt)

        # Start chat complete
        if on_message is not None:
            response = await self.openai_api.chat_complete_stream(
                self.question, system_prompt, self.model, message_log, on_message
            )
        else:
            response = await self.openai_api.chat_complete(
                self.question, system_prompt, self.model, message_log
            )

        # The first 150 characters of the reply become the description.
        description = response["message"][0:150]

        question_msg_id = utils.web.generate_uuid()
        response_msg_id = utils.web.generate_uuid()

        new_message_data = [
            {
                "id": question_msg_id,
                "role": "user",
                "content": self.question,
                "tokens": self.question_tokens,
                "time": self.conversation_start_time,
            },
            {
                "id": response_msg_id,
                "role": "assistant",
                "content": response["message"],
                "tokens": response["message_tokens"],
                "time": int(time.time()),
            },
        ]

        if self.conversation_info is not None:
            total_token_usage = self.question_tokens + response["message_tokens"]

            # Generate title if not exists
            if self.conversation_info.title is None:
                title = None
                try:
                    title, _ = await self.make_title(new_message_data)
                    delta_data["title"] = title
                except Exception as e:
                    # Best effort: a failed title must not fail the answer.
                    print(str(e), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)

                self.conversation_info.title = title

            # Update conversation info
            self.conversation_info.description = description
            await self.conversation_helper.update(self.conversation_info)

            # Update conversation chunk
            self.conversation_chunk.message_data.extend(new_message_data)
            flag_modified(self.conversation_chunk, "message_data")
            self.conversation_chunk.tokens += total_token_usage
            await self.conversation_chunk_helper.update(self.conversation_chunk)

        return ChatCompleteServiceResponse(
            message=response["message"],
            message_tokens=response["message_tokens"],
            total_tokens=response["total_tokens"],
            finish_reason=response["finish_reason"],
            question_message_id=question_msg_id,
            response_message_id=response_msg_id,
            delta_data=delta_data,
        )

    async def set_latest_point_usage(self, point_usage: int) -> bool:
        """Record ``point_usage`` on the newest assistant message of the chunk.

        Returns:
            True when an assistant message was found and the chunk was
            updated; False when there is no chunk, the chunk has no messages,
            or none of its messages has the "assistant" role.
        """
        if self.conversation_chunk is None:
            return False

        message_data = self.conversation_chunk.message_data
        if len(message_data) == 0:
            return False

        # Walk backwards so the most recent assistant message is hit first.
        for message in reversed(message_data):
            if message.get("role") == "assistant":
                message["point_usage"] = point_usage
                flag_modified(self.conversation_chunk, "message_data")
                await self.conversation_chunk_helper.update(self.conversation_chunk)
                return True

        return False

    async def make_summary(self, message_log_list: list) -> tuple[str, int]:
        """Summarise a message log with the system API.

        Args:
            message_log_list: Stored message dicts to summarise.

        Returns:
            ``(summary_text, summary_tokens)`` taken from the API response.
        """
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
        api_id = Config.get("chatcomplete.system_api_id", "default", str)
        # NOTE(review): this reads the same key as api_id; a dedicated model
        # key may have been intended -- confirm before changing.
        model_id = Config.get("chatcomplete.system_api_id", "default", str)

        openai_api = OpenAIApi.create(api_id)

        chat_log_str = self._format_chat_log(message_log_list, bot_name)

        summary_system_prompt = utils.config.get_prompt("make_summary", "system")
        summary_prompt = utils.config.get_prompt(
            "make_summary", "prompt", {"content": chat_log_str}
        )

        response = await openai_api.chat_complete(
            summary_prompt, summary_system_prompt, model_id
        )

        return response["message"], response["message_tokens"]

    async def make_title(self, message_log_list: list) -> tuple[str, int]:
        """Generate a conversation title from a message log with the system API.

        Args:
            message_log_list: Stored message dicts to derive the title from.

        Returns:
            ``(title, title_tokens)``; the title is cut to 250 characters.

        Raises:
            Exception: The API response carries no message.
        """
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
        api_id = Config.get("chatcomplete.system_api_id", "default", str)
        # NOTE(review): this reads the same key as api_id; a dedicated model
        # key may have been intended -- confirm before changing.
        model_id = Config.get("chatcomplete.system_api_id", "default", str)

        openai_api = OpenAIApi.create(api_id)

        chat_log_str = self._format_chat_log(message_log_list, bot_name)

        title_system_prompt = utils.config.get_prompt("make_title", "system")
        title_prompt = utils.config.get_prompt(
            "make_title", "prompt", {"content": chat_log_str}
        )

        response = await openai_api.chat_complete(
            title_prompt, title_system_prompt, model_id
        )

        if response["message"] is None:
            # Diagnostics go to stderr like the other error output here.
            print(response, file=sys.stderr)
            raise Exception("Title generation failed")

        return response["message"][0:250], response["message_tokens"]
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_point_usage(tokens: int, cost_fixed: float, cost_fixed_tokens: int, cost_per_token: float) -> int:
    """Return the point cost for a request that used ``tokens`` tokens.

    Up to ``cost_fixed_tokens`` tokens cost the flat ``cost_fixed``. Every
    token beyond that adds ``cost_per_token``; the surcharge is rounded up to
    a whole number before it is added to the flat cost.

    Args:
        tokens: Number of tokens used.
        cost_fixed: Flat cost covering the first ``cost_fixed_tokens`` tokens.
        cost_fixed_tokens: Number of tokens included in the flat cost.
        cost_per_token: Cost of each token beyond the included amount.

    Returns:
        The total cost. NOTE(review): ``cost_fixed`` is returned or added
        as-is, so the result is a float whenever ``cost_fixed`` is a float,
        despite the ``int`` annotation.
    """
    extra_tokens = tokens - cost_fixed_tokens
    if extra_tokens <= 0:
        return cost_fixed
    return cost_fixed + math.ceil(extra_tokens * cost_per_token)
|