from __future__ import annotations import math import time import traceback from typing import Optional, Tuple, TypedDict from server.model.chat_complete.bot_persona import BotPersonaHelper from server.model.chat_complete.conversation import ( ConversationChunkHelper, ConversationChunkModel, ) import sys from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel from libs.config import Config import utils.config, utils.web from aiohttp import web from sqlalchemy.orm.attributes import flag_modified from service.database import DatabaseService from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService from service.mediawiki_api import MediaWikiApi from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException from service.tiktoken import TikTokenService class ChatCompleteQuestionTooLongException(Exception): def __init__(self, tokens_limit: int, tokens_current: int): super().__init__(f"Question too long: {tokens_current} > {tokens_limit}") self.tokens_limit = tokens_limit self.tokens_current = tokens_current class ChatCompleteServicePrepareResponse(TypedDict): extract_doc: list question_tokens: int conversation_id: int chunk_id: int api_id: str class ChatCompleteServiceResponse(TypedDict): message: str message_tokens: int total_tokens: int finish_reason: str question_message_id: str response_message_id: str delta_data: dict class ChatCompleteService: def __init__(self, dbs: DatabaseService, title: str): self.dbs = dbs self.title = title self.base_title = title.split("/")[0] self.embedding_search = EmbeddingSearchService(dbs, title) self.conversation_helper = ConversationHelper(dbs) self.conversation_chunk_helper = ConversationChunkHelper(dbs) self.bot_persona_helper = BotPersonaHelper(dbs) self.conversation_info: Optional[ConversationModel] = None self.conversation_chunk: Optional[ConversationChunkModel] = None self.openai_api: OpenAIApi = None self.tiktoken: TikTokenService = None self.extract_doc: list = None self.mwapi = MediaWikiApi.create() self.user_id = 0 self.question = "" self.question_tokens: Optional[int] = None self.bot_id: str = "" self.model: Optional[str] = None self.conversation_id: Optional[int] = None self.conversation_start_time: Optional[int] = None self.chat_system_prompt = "" self.delta_data = {} async def __aenter__(self): self.tiktoken = await TikTokenService.create() await self.embedding_search.__aenter__() await self.conversation_helper.__aenter__() await self.conversation_chunk_helper.__aenter__() await self.bot_persona_helper.__aenter__() return self async def __aexit__(self, exc_type, exc, tb): await self.embedding_search.__aexit__(exc_type, exc, tb) await self.conversation_helper.__aexit__(exc_type, exc, tb) await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb) await self.bot_persona_helper.__aexit__(exc_type, exc, tb) async def page_index_exists(self): return await self.embedding_search.page_index_exists(False) async def get_question_tokens(self, question: str): return await self.tiktoken.get_tokens(question) async def prepare_chat_complete( self, question: str, bot_id: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None, edit_message_id: Optional[str] = None, embedding_search: Optional[EmbeddingSearchArgs] = None, ) -> ChatCompleteServicePrepareResponse: if user_id is not None: user_id = int(user_id) self.user_id = user_id self.question = question self.conversation_start_time = int(time.time()) self.bot_id = bot_id self.conversation_info = None if conversation_id is not None: self.conversation_id = int(conversation_id) self.conversation_info = await self.conversation_helper.find_by_id( self.conversation_id ) else: self.conversation_id = None if self.conversation_info is not None: if self.conversation_info.user_id != user_id: raise web.HTTPUnauthorized() if question_tokens is None: self.question_tokens = await self.get_question_tokens(question) else: self.question_tokens = question_tokens if self.question_tokens == 0: self.question_tokens = len(question) * 3 max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int) if ( self.question_tokens > max_input_tokens ): # If the question is too long, we need to truncate it raise ChatCompleteQuestionTooLongException(max_input_tokens, self.question_tokens) if self.conversation_info is not None: self.bot_id = self.conversation_info.extra.get("bot_id") or "default" bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id) if bot_persona is None: self.bot_id = "default" self.model = None bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id) else: self.model = bot_persona.model_id self.chat_system_prompt = bot_persona.system_prompt default_api = Config.get("chatcomplete.default_api", None, str) try: self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api) except OpenAIApiTypeInvalidException: print(f"Invalid API type: {bot_persona.api_id}", file=sys.stderr) self.openai_api = OpenAIApi.create(default_api) self.conversation_chunk = None if self.conversation_info is not None: chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id) if edit_message_id and "," in edit_message_id: (edit_chunk_id, edit_msg_id) = edit_message_id.split(",") edit_chunk_id = int(edit_chunk_id) # Remove overrided conversation chunks start_overrided = False should_remove_chunk_ids = [] for chunk_id in chunk_id_list: if start_overrided: should_remove_chunk_ids.append(chunk_id) else: if chunk_id == edit_chunk_id: start_overrided = True if len(should_remove_chunk_ids) > 0: await self.conversation_chunk_helper.remove(should_remove_chunk_ids) self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk( self.conversation_id ) # Remove outdated message edit_message_pos = None old_tokens = 0 for i in range(0, len(self.conversation_chunk.message_data)): msg_data: dict = self.conversation_chunk.message_data[i] if msg_data.get("id") == edit_msg_id: edit_message_pos = i break if "tokens" in msg_data and msg_data["tokens"] is not None: old_tokens += msg_data["tokens"] if edit_message_pos: self.conversation_chunk.message_data = self.conversation_chunk.message_data[0:edit_message_pos] flag_modified(self.conversation_chunk, "message_data") self.conversation_chunk.tokens = old_tokens else: self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk( self.conversation_id ) # If the conversation is too long, we need to make a summary max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int) if self.conversation_chunk.tokens > max_memory_tokens: summary, tokens = await self.make_summary( self.conversation_chunk.message_data ) new_message_log = [ { "role": "summary", "content": summary, "tokens": tokens, "time": int(time.time()), } ] self.conversation_chunk = ConversationChunkModel( conversation_id=self.conversation_id, message_data=new_message_log, tokens=tokens, ) self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk) else: # 创建新对话 # 默认聊天记录 init_message_data = [] if bot_persona is not None: current_time = int(time.time()) if bot_persona.message_log is not None: for message in bot_persona.message_log: message["id"] = utils.web.generate_uuid() message["time"] = current_time init_message_data.append(message) title_info = self.embedding_search.title_index self.conversation_info = ConversationModel( user_id=self.user_id, module="chatcomplete", page_id=title_info.page_id, rev_id=title_info.latest_rev_id, extra={"bot_id": self.bot_id}, ) self.conversation_info = await self.conversation_helper.add( self.conversation_info, ) self.conversation_chunk = ConversationChunkModel( conversation_id=self.conversation_info.id, message_data=init_message_data, tokens=0, ) self.conversation_chunk = await self.conversation_chunk_helper.add( self.conversation_chunk ) # Extract document from wiki page index self.extract_doc = None if embedding_search is not None: self.extract_doc, token_usage = await self.embedding_search.search( question, **embedding_search ) if self.extract_doc is not None: self.question_tokens += token_usage return ChatCompleteServicePrepareResponse( extract_doc=self.extract_doc, question_tokens=self.question_tokens, conversation_id=self.conversation_info.id, chunk_id=self.conversation_chunk.id ) async def finish_chat_complete( self, on_message: Optional[callable] = None ) -> ChatCompleteServiceResponse: delta_data = {} message_log = [] if self.conversation_chunk is not None: for message in self.conversation_chunk.message_data: if message.get("role") in ["user", "assistant"]: message_log.append( { "role": message["role"], "content": message["content"], } ) if self.extract_doc is not None: doc_prompt_content = "\n".join( [ "%d. %s" % (i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc) ] ) doc_prompt = utils.config.get_prompt( "extracted_doc", "prompt", {"content": doc_prompt_content} ) message_log.append({"role": "user", "content": doc_prompt}) system_prompt = self.chat_system_prompt if system_prompt is None: system_prompt = utils.config.get_prompt("default", "system") if system_prompt is None: raise Exception("System prompt not found.") system_prompt = utils.config.format_prompt(system_prompt) # Start chat complete if on_message is not None: response = await self.openai_api.chat_complete_stream( self.question, system_prompt, self.model, message_log, on_message ) else: response = await self.openai_api.chat_complete( self.question, system_prompt, self.model, message_log ) description = response["message"][0:150] question_msg_id = utils.web.generate_uuid() response_msg_id = utils.web.generate_uuid() new_message_data = [ { "id": question_msg_id, "role": "user", "content": self.question, "tokens": self.question_tokens, "time": self.conversation_start_time, }, { "id": response_msg_id, "role": "assistant", "content": response["message"], "tokens": response["message_tokens"], "time": int(time.time()), }, ] if self.conversation_info is not None: total_token_usage = self.question_tokens + response["message_tokens"] # Generate title if not exists if self.conversation_info.title is None: title = None try: title, token_usage = await self.make_title(new_message_data) delta_data["title"] = title except Exception as e: print(str(e), file=sys.stderr) traceback.print_exc(file=sys.stderr) self.conversation_info.title = title # Update conversation info self.conversation_info.description = description await self.conversation_helper.update(self.conversation_info) # Update conversation chunk self.conversation_chunk.message_data.extend(new_message_data) flag_modified(self.conversation_chunk, "message_data") self.conversation_chunk.tokens += total_token_usage await self.conversation_chunk_helper.update(self.conversation_chunk) return ChatCompleteServiceResponse( message=response["message"], message_tokens=response["message_tokens"], total_tokens=response["total_tokens"], finish_reason=response["finish_reason"], question_message_id=question_msg_id, response_message_id=response_msg_id, delta_data=delta_data, ) async def set_latest_point_usage(self, point_usage: int) -> bool: if self.conversation_chunk is None: return False if len(self.conversation_chunk.message_data) == 0: return False for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): if self.conversation_chunk.message_data[i].get("role") == "assistant": self.conversation_chunk.message_data[i]["point_usage"] = point_usage flag_modified(self.conversation_chunk, "message_data") await self.conversation_chunk_helper.update(self.conversation_chunk) return True async def make_summary(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) api_id = Config.get("chatcomplete.system_api_id", "default", str) model_id = Config.get("chatcomplete.system_model_id", "default", str) openai_api = OpenAIApi.create(api_id) for message_data in message_log_list: if "content" in message_data: msg_role = message_data.get("role") if msg_role == "summary": chat_log.append(message_data["content"]) elif msg_role == "assistant": chat_log.append( f'{bot_name}: {message_data["content"]}' ) else: chat_log.append(f'User: {message_data["content"]}') chat_log_str = "\n".join(chat_log) summary_system_prompt = utils.config.get_prompt("make_summary", "system") summary_prompt = utils.config.get_prompt( "make_summary", "prompt", {"content": chat_log_str} ) response = await openai_api.chat_complete( summary_prompt, summary_system_prompt, model_id ) return response["message"], response["message_tokens"] async def make_title(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) api_id = Config.get("chatcomplete.system_api_id", "default", str) model_id = Config.get("chatcomplete.system_model_id", "default", str) openai_api = OpenAIApi.create(api_id) for message_data in message_log_list: if "content" in message_data: msg_role = message_data.get("role") if msg_role == "summary": chat_log.append(message_data["content"]) elif msg_role == "assistant": chat_log.append( f'{bot_name}: {message_data["content"]}' ) else: chat_log.append(f'User: {message_data["content"]}') chat_log_str = "\n".join(chat_log) title_system_prompt = utils.config.get_prompt("make_title", "system") title_prompt = utils.config.get_prompt( "make_title", "prompt", {"content": chat_log_str} ) response = await openai_api.chat_complete( title_prompt, title_system_prompt, model_id ) if response["message"] is None: print(response) raise Exception("Title generation failed") return response["message"][0:250], response["message_tokens"] def calculate_point_usage(tokens: int, cost_fixed: float, cost_fixed_tokens: int, cost_per_token: float) -> int: if tokens <= cost_fixed_tokens: return cost_fixed return cost_fixed + math.ceil((tokens - cost_fixed_tokens) * cost_per_token)