You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

410 lines
15 KiB
Python

from __future__ import annotations
import time
import traceback
from typing import Optional, Tuple, TypedDict
import sqlalchemy
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.chat_complete.conversation import (
ConversationChunkHelper,
ConversationChunkModel,
)
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
from config import Config
import utils.config, utils.web
from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionModel
from sqlalchemy.orm.attributes import flag_modified
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
class ChatCompleteServicePrepareResponse(TypedDict):
    """Payload returned by ``ChatCompleteService.prepare_chat_complete``."""

    # Result of the embedding search for the question; the service passes
    # None here when no search was requested.
    extract_doc: list

    # Token count of the question; the service adds the search's token usage
    # on top when the search returned documents.
    question_tokens: int

    # ``id`` of the conversation record that was loaded or newly created.
    conversation_id: int

    # ``id`` of the conversation chunk the next messages will be appended to.
    chunk_id: int
class ChatCompleteServiceResponse(TypedDict):
    """Payload returned by ``ChatCompleteService.finish_chat_complete``."""

    # The following four values are copied from the chat-completion API
    # response ("message", "message_tokens", "total_tokens", "finish_reason").
    message: str
    message_tokens: int
    total_tokens: int
    finish_reason: str

    # Freshly generated ids of the stored user question and assistant answer.
    question_message_id: str
    response_message_id: str

    # Conversation fields changed during this round; the service sets
    # "title" here when it generated one.
    delta_data: dict
class ChatCompleteService:
    """Run one question/answer round of a chat conversation about a wiki page.

    The service is an async context manager. As far as this class shows, the
    intended flow is: enter the context, call ``prepare_chat_complete`` to
    validate the question and load (or create) the conversation, then call
    ``finish_chat_complete`` to query the model and persist the new messages.
    """

    def __init__(self, dbs: DatabaseService, title: str):
        """Create the helpers and reset all per-request state.

        Args:
            dbs: Database service handed to every helper created here.
            title: Page title; the part before the first "/" is kept as
                ``base_title``.
        """
        self.dbs = dbs

        self.title = title
        self.base_title = title.split("/")[0]

        self.embedding_search = EmbeddingSearchService(dbs, title)
        self.conversation_helper = ConversationHelper(dbs)
        self.conversation_chunk_helper = ConversationChunkHelper(dbs)
        self.conversation_info: Optional[ConversationModel] = None
        self.conversation_chunk: Optional[ConversationChunkModel] = None

        # Created in __aenter__, so it stays None until the context is entered.
        self.tiktoken: Optional[TikTokenService] = None

        self.extract_doc: Optional[list] = None

        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        self.user_id = 0
        self.question = ""
        self.question_tokens: Optional[int] = None
        self.conversation_id: Optional[int] = None
        self.conversation_start_time: Optional[int] = None

        self.delta_data = {}

    async def __aenter__(self):
        """Create the tokenizer service and enter every helper's context."""
        self.tiktoken = await TikTokenService.create()

        await self.embedding_search.__aenter__()
        await self.conversation_helper.__aenter__()
        await self.conversation_chunk_helper.__aenter__()

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Leave every helper's context, even when one of them fails.

        The nested ``finally`` blocks guarantee that an exception raised by
        one helper's ``__aexit__`` does not prevent the remaining helpers
        from being closed. Nothing is returned, so exceptions from the
        ``async with`` body are never suppressed.
        """
        try:
            await self.embedding_search.__aexit__(exc_type, exc, tb)
        finally:
            try:
                await self.conversation_helper.__aexit__(exc_type, exc, tb)
            finally:
                await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self):
        """Return what the embedding search reports for ``page_index_exists(False)``."""
        return await self.embedding_search.page_index_exists(False)

    async def get_question_tokens(self, question: str):
        """Return the token count of ``question`` as computed by the tokenizer service."""
        return await self.tiktoken.get_tokens(question)

    async def prepare_chat_complete(
        self,
        question: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[int] = None,
        question_tokens: Optional[int] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Validate the question and load or create the conversation state.

        Args:
            question: The user's question.
            conversation_id: Id of an existing conversation. When it is None,
                or no conversation with that id is found, a new conversation
                is created.
            user_id: Id of the asking user; must match the owner of an
                existing conversation.
            question_tokens: Pre-computed token count of ``question``; it is
                computed here when omitted.
            edit_message_id: ``"<chunk_id>,<message_id>"`` of a message to
                re-ask. That message and everything after it is discarded.
                Values without a comma are ignored.
            embedding_search: Keyword arguments for the embedding search; no
                search is run when None.

        Returns:
            The extracted documents, the question token count and the ids of
            the conversation and of its current chunk.

        Raises:
            web.HTTPUnauthorized: The conversation belongs to another user.
            web.HTTPRequestEntityTooLarge: The question exceeds
                ``chatcomplete.max_input_tokens``.
        """
        if user_id is not None:
            user_id = int(user_id)

        self.user_id = user_id
        self.question = question
        self.conversation_start_time = int(time.time())

        self.conversation_info = None
        if conversation_id is not None:
            self.conversation_id = int(conversation_id)
            self.conversation_info = await self.conversation_helper.find_by_id(
                self.conversation_id
            )
        else:
            self.conversation_id = None

        if (
            self.conversation_info is not None
            and self.conversation_info.user_id != user_id
        ):
            raise web.HTTPUnauthorized()

        if question_tokens is None:
            self.question_tokens = await self.get_question_tokens(question)
        else:
            self.question_tokens = question_tokens

        max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int)
        if (
            len(question) * 4 > max_input_tokens
            and self.question_tokens > max_input_tokens
        ):
            # Over-long questions are rejected, not truncated. aiohttp's
            # HTTPRequestEntityTooLarge requires both sizes; calling it with
            # no arguments raises TypeError instead of producing a 413.
            raise web.HTTPRequestEntityTooLarge(
                max_size=max_input_tokens, actual_size=self.question_tokens
            )

        self.conversation_chunk = None
        if self.conversation_info is not None:
            if edit_message_id and "," in edit_message_id:
                await self._rewind_to_edited_message(edit_message_id)
            else:
                self.conversation_chunk = (
                    await self.conversation_chunk_helper.get_newest_chunk(
                        self.conversation_id
                    )
                )
            await self._summarize_chunk_if_too_long()
        else:
            await self._start_new_conversation()

        # Extract documents from the wiki page index
        self.extract_doc = None
        if embedding_search is not None:
            self.extract_doc, token_usage = await self.embedding_search.search(
                question, **embedding_search
            )
            if self.extract_doc is not None:
                self.question_tokens += token_usage

        return ChatCompleteServicePrepareResponse(
            extract_doc=self.extract_doc,
            question_tokens=self.question_tokens,
            conversation_id=self.conversation_info.id,
            chunk_id=self.conversation_chunk.id,
        )

    async def _rewind_to_edited_message(self, edit_message_id: str) -> None:
        """Drop the edited message and everything that followed it."""
        # maxsplit=1 keeps a stray extra comma from breaking the unpacking;
        # such an id simply matches no stored message.
        raw_chunk_id, edit_msg_id = edit_message_id.split(",", 1)
        edit_chunk_id = int(raw_chunk_id)

        chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(
            self.conversation_id
        )

        # Every chunk listed after the edited one is overridden by the edit.
        stale_chunk_ids = []
        reached_edit_chunk = False
        for chunk_id in chunk_id_list:
            if reached_edit_chunk:
                stale_chunk_ids.append(chunk_id)
            elif chunk_id == edit_chunk_id:
                reached_edit_chunk = True

        if stale_chunk_ids:
            await self.conversation_chunk_helper.remove(stale_chunk_ids)

        self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
            self.conversation_id
        )

        # Cut the chunk right before the edited message and recompute its
        # token total from the messages that remain.
        message_data = self.conversation_chunk.message_data
        kept_tokens = 0
        for position, msg_data in enumerate(message_data):
            if msg_data.get("id") == edit_msg_id:
                # Truncating here (rather than testing the position's
                # truthiness afterwards) also handles position 0.
                self.conversation_chunk.message_data = message_data[:position]
                flag_modified(self.conversation_chunk, "message_data")
                self.conversation_chunk.tokens = kept_tokens
                break
            if msg_data.get("tokens") is not None:
                kept_tokens += msg_data["tokens"]

    async def _summarize_chunk_if_too_long(self) -> None:
        """Replace an over-long chunk with a new chunk holding only its summary."""
        max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int)
        if self.conversation_chunk.tokens <= max_memory_tokens:
            return

        summary, tokens = await self.make_summary(self.conversation_chunk.message_data)
        new_message_log = [
            {
                "role": "summary",
                "content": summary,
                "tokens": tokens,
                "time": int(time.time()),
            }
        ]

        summary_chunk = ConversationChunkModel(
            conversation_id=self.conversation_id,
            message_data=new_message_log,
            tokens=tokens,
        )
        self.conversation_chunk = await self.conversation_chunk_helper.add(summary_chunk)

    async def _start_new_conversation(self) -> None:
        """Create a new conversation record together with its first, empty chunk."""
        title_info = self.embedding_search.title_index

        new_conversation = ConversationModel(
            user_id=self.user_id,
            module="chatcomplete",
            page_id=title_info.page_id,
            rev_id=title_info.latest_rev_id,
        )
        self.conversation_info = await self.conversation_helper.add(new_conversation)

        first_chunk = ConversationChunkModel(
            conversation_id=self.conversation_info.id,
            message_data=[],
            tokens=0,
        )
        self.conversation_chunk = await self.conversation_chunk_helper.add(first_chunk)

    async def finish_chat_complete(
        self, on_message: Optional[callable] = None
    ) -> ChatCompleteServiceResponse:
        """Ask the model the prepared question and store the exchange.

        Args:
            on_message: When given, the streaming chat API is used and this
                callback is passed through to it.

        Returns:
            The answer, its token statistics, the ids of the two stored
            messages and the conversation fields that changed.

        Raises:
            Exception: No system prompt could be resolved.
        """
        delta_data = {}

        # Only user/assistant turns are replayed to the model.
        message_log = []
        if self.conversation_chunk is not None:
            message_log = [
                {"role": message["role"], "content": message["content"]}
                for message in self.conversation_chunk.message_data
                if message["role"] in ["user", "assistant"]
            ]

        if self.extract_doc is not None:
            doc_prompt_content = "\n".join(
                "%d. %s" % (index, doc["markdown"] or doc["text"])
                for index, doc in enumerate(self.extract_doc, start=1)
            )
            doc_prompt = utils.config.get_prompt(
                "extracted_doc", "prompt", {"content": doc_prompt_content}
            )
            message_log.append({"role": "user", "content": doc_prompt})

        # conversation_info is treated as optional further down, and a newly
        # created conversation is built without ``extra``, so both may be
        # None here; fall back to the default persona instead of crashing.
        extra = None
        if self.conversation_info is not None:
            extra = self.conversation_info.extra
        bot_persona = (extra or {}).get("bot_persona") or "default"

        system_prompt = await BotPersonaHelper.get_cached_system_prompt(
            self.dbs, bot_persona
        )
        if system_prompt is None:
            system_prompt = await BotPersonaHelper.get_cached_system_prompt(
                self.dbs, "default"
            )
        if system_prompt is None:
            system_prompt = utils.config.get_prompt("default", "system")
        if system_prompt is None:
            raise Exception("System prompt not found.")

        # Start chat complete
        if on_message is not None:
            response = await self.openai_api.chat_complete_stream(
                self.question, system_prompt, message_log, on_message
            )
        else:
            response = await self.openai_api.chat_complete(
                self.question, system_prompt, message_log
            )

        description = response["message"][0:150]

        question_msg_id = utils.web.generate_uuid()
        response_msg_id = utils.web.generate_uuid()

        new_message_data = [
            {
                "id": question_msg_id,
                "role": "user",
                "content": self.question,
                "tokens": self.question_tokens,
                "time": self.conversation_start_time,
            },
            {
                "id": response_msg_id,
                "role": "assistant",
                "content": response["message"],
                "tokens": response["message_tokens"],
                "time": int(time.time()),
            },
        ]

        if self.conversation_info is not None:
            total_token_usage = self.question_tokens + response["message_tokens"]

            # Generate a title if the conversation has none yet. This is
            # deliberately best-effort: a failure is logged and the title
            # stays None, the answer itself is still returned.
            if self.conversation_info.title is None:
                title = None
                try:
                    title, _ = await self.make_title(new_message_data)
                    delta_data["title"] = title
                except Exception as e:
                    print(str(e), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)

                self.conversation_info.title = title

            # Update conversation info
            self.conversation_info.description = description

            await self.conversation_helper.update(self.conversation_info)

            # Update conversation chunk; the list is mutated in place, so the
            # attribute has to be flagged as modified explicitly.
            self.conversation_chunk.message_data.extend(new_message_data)
            flag_modified(self.conversation_chunk, "message_data")
            self.conversation_chunk.tokens += total_token_usage

            await self.conversation_chunk_helper.update(self.conversation_chunk)

        return ChatCompleteServiceResponse(
            message=response["message"],
            message_tokens=response["message_tokens"],
            total_tokens=response["total_tokens"],
            finish_reason=response["finish_reason"],
            question_message_id=question_msg_id,
            response_message_id=response_msg_id,
            delta_data=delta_data,
        )

    async def set_latest_point_cost(self, point_cost: int) -> bool:
        """Record ``point_cost`` on the most recent assistant message.

        Returns:
            True when an assistant message was found and the chunk was
            updated; False when there is no chunk, the chunk has no
            messages, or none of its messages is an assistant message.
        """
        if self.conversation_chunk is None:
            return False

        message_data = self.conversation_chunk.message_data
        if not message_data:
            return False

        for message in reversed(message_data):
            if message["role"] == "assistant":
                message["point_cost"] = point_cost
                flag_modified(self.conversation_chunk, "message_data")
                await self.conversation_chunk_helper.update(self.conversation_chunk)
                return True

        # No assistant message to annotate.
        return False

    async def make_summary(self, message_log_list: list) -> tuple[str, int]:
        """Ask the model to summarise a message log.

        Earlier summaries are passed through verbatim, assistant messages are
        prefixed with the configured bot name, and every other role is
        prefixed with "User".

        Returns:
            The summary text and its token count.
        """
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)

        chat_log: list[str] = []
        for message_data in message_log_list:
            role = message_data["role"]
            content = message_data["content"]
            if role == "summary":
                chat_log.append(content)
            elif role == "assistant":
                chat_log.append(f"{bot_name}: {content}")
            else:
                chat_log.append(f"User: {content}")

        chat_log_str = "\n".join(chat_log)

        summary_system_prompt = utils.config.get_prompt("make_summary", "system")
        summary_prompt = utils.config.get_prompt(
            "make_summary", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            summary_prompt, summary_system_prompt
        )

        return response["message"], response["message_tokens"]

    async def make_title(self, message_log_list: list) -> tuple[str, int]:
        """Ask the model for a conversation title based on a message log.

        Only assistant and user messages are included; other roles are
        skipped.

        Returns:
            The title text and its token count.
        """
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)

        chat_log: list[str] = []
        for message_data in message_log_list:
            role = message_data["role"]
            if role == "assistant":
                chat_log.append(f'{bot_name}: {message_data["content"]}')
            elif role == "user":
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

        title_system_prompt = utils.config.get_prompt("make_title", "system")
        title_prompt = utils.config.get_prompt(
            "make_title", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            title_prompt, title_system_prompt
        )

        return response["message"], response["message_tokens"]