from __future__ import annotations

import sys
import time
import traceback
from typing import Callable, Optional, Tuple, TypedDict

import sqlalchemy
from aiohttp import web
from sqlalchemy.orm.attributes import flag_modified

import config
import utils.config
import utils.web
from api.model.chat_complete.conversation import (
    ConversationChunkHelper,
    ConversationChunkModel,
)
from api.model.embedding_search.title_collection import TitleCollectionModel
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService


class ChatCompleteServicePrepareResponse(TypedDict):
    """Result of :meth:`ChatCompleteService.prepare_chat_complete`."""

    extract_doc: list
    question_tokens: int
    conversation_id: int
    chunk_id: int


class ChatCompleteServiceResponse(TypedDict):
    """Result of :meth:`ChatCompleteService.finish_chat_complete`."""

    message: str
    message_tokens: int
    total_tokens: int
    finish_reason: str
    question_message_id: str
    response_message_id: str
    delta_data: dict


class ChatCompleteService:
    """Orchestrates a chat-complete round for one wiki page.

    Lifecycle: use as an async context manager, then call
    ``prepare_chat_complete`` followed by ``finish_chat_complete``.
    The instance carries the per-request state between the two calls.
    """

    def __init__(self, dbs: DatabaseService, title: str):
        self.dbs = dbs
        self.title = title
        # Base page title without the subpage suffix (e.g. "Page/sub" -> "Page").
        self.base_title = title.split("/")[0]

        self.embedding_search = EmbeddingSearchService(dbs, title)
        self.conversation_helper = ConversationHelper(dbs)
        self.conversation_chunk_helper = ConversationChunkHelper(dbs)

        self.conversation_info: Optional[ConversationModel] = None
        self.conversation_chunk: Optional[ConversationChunkModel] = None

        # Created lazily in __aenter__ (requires an await).
        self.tiktoken: Optional[TikTokenService] = None

        # Documents extracted by embedding search during prepare_chat_complete.
        self.extract_doc: Optional[list] = None

        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        # Per-request state filled in by prepare_chat_complete.
        self.user_id = 0
        self.question = ""
        self.question_tokens: Optional[int] = None
        self.conversation_id: Optional[int] = None
        self.conversation_start_time: Optional[int] = None

        self.delta_data = {}

    async def __aenter__(self):
        self.tiktoken = await TikTokenService.create()

        await self.embedding_search.__aenter__()
        await self.conversation_helper.__aenter__()
        await self.conversation_chunk_helper.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Close every sub-service even if an earlier __aexit__ raises,
        # so a single failure cannot leak the remaining resources.
        try:
            await self.embedding_search.__aexit__(exc_type, exc, tb)
        finally:
            try:
                await self.conversation_helper.__aexit__(exc_type, exc, tb)
            finally:
                await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self):
        """Return True if an embedding index exists for this page."""
        return await self.embedding_search.page_index_exists(False)

    async def get_question_tokens(self, question: str):
        """Count tokens in *question* with the tiktoken service."""
        return await self.tiktoken.get_tokens(question)

    async def prepare_chat_complete(
        self,
        question: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[int] = None,
        question_tokens: Optional[int] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Load/create the conversation, enforce limits and run doc extraction.

        :param question: the user's question text.
        :param conversation_id: existing conversation to continue, or None to
            create a new one.
        :param user_id: requesting user; must own the conversation.
        :param question_tokens: pre-computed token count (skips re-counting).
        :param edit_message_id: ``"<chunk_id>,<message_id>"`` of a message being
            edited; later chunks/messages are discarded.
        :param embedding_search: if given, arguments for the embedding search
            used to extract supporting documents from the page index.
        :raises web.HTTPUnauthorized: conversation belongs to another user.
        :raises web.HTTPRequestEntityTooLarge: question exceeds the token limit.
        """
        if user_id is not None:
            user_id = int(user_id)

        self.user_id = user_id
        self.question = question
        self.conversation_start_time = int(time.time())

        self.conversation_info = None
        if conversation_id is not None:
            self.conversation_id = int(conversation_id)
            self.conversation_info = await self.conversation_helper.find_by_id(
                self.conversation_id
            )
        else:
            self.conversation_id = None

        if self.conversation_info is not None:
            if self.conversation_info.user_id != user_id:
                raise web.HTTPUnauthorized()

        if question_tokens is None:
            self.question_tokens = await self.get_question_tokens(question)
        else:
            self.question_tokens = question_tokens

        # Cheap length pre-check (len * 4 overestimates tokens) guards the
        # exact token comparison; both must exceed the limit to reject.
        if (
            len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS
            and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS
        ):
            # If the question is too long, we need to truncate it
            raise web.HTTPRequestEntityTooLarge()

        self.conversation_chunk = None
        if self.conversation_info is not None:
            chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(
                self.conversation_id
            )
            if edit_message_id and "," in edit_message_id:
                (edit_chunk_id, edit_msg_id) = edit_message_id.split(",")
                edit_chunk_id = int(edit_chunk_id)

                # Remove overrided conversation chunks: everything AFTER the
                # chunk containing the edited message is discarded.
                start_overrided = False
                should_remove_chunk_ids = []
                for chunk_id in chunk_id_list:
                    if start_overrided:
                        should_remove_chunk_ids.append(chunk_id)
                    else:
                        if chunk_id == edit_chunk_id:
                            start_overrided = True
                if len(should_remove_chunk_ids) > 0:
                    await self.conversation_chunk_helper.remove(should_remove_chunk_ids)

                self.conversation_chunk = (
                    await self.conversation_chunk_helper.get_newest_chunk(
                        self.conversation_id
                    )
                )

                # Remove outdated message: truncate the chunk at the edited
                # message, keeping the token count of what remains.
                edit_message_pos = None
                old_tokens = 0
                for i in range(0, len(self.conversation_chunk.message_data)):
                    msg_data = self.conversation_chunk.message_data[i]
                    if msg_data["id"] == edit_msg_id:
                        edit_message_pos = i
                        break
                    if "tokens" in msg_data and msg_data["tokens"]:
                        old_tokens += msg_data["tokens"]

                # BUGFIX: position 0 (editing the first message) is valid;
                # a plain truthiness test skipped the truncation for it.
                if edit_message_pos is not None:
                    self.conversation_chunk.message_data = (
                        self.conversation_chunk.message_data[0:edit_message_pos]
                    )
                    flag_modified(self.conversation_chunk, "message_data")
                    self.conversation_chunk.tokens = old_tokens
            else:
                self.conversation_chunk = (
                    await self.conversation_chunk_helper.get_newest_chunk(
                        self.conversation_id
                    )
                )

                # If the conversation is too long, we need to make a summary
                if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
                    summary, tokens = await self.make_summary(
                        self.conversation_chunk.message_data
                    )
                    new_message_log = [
                        {
                            "role": "summary",
                            "content": summary,
                            "tokens": tokens,
                            "time": int(time.time()),
                        }
                    ]

                    self.conversation_chunk = ConversationChunkModel(
                        conversation_id=self.conversation_id,
                        message_data=new_message_log,
                        tokens=tokens,
                    )
                    self.conversation_chunk = await self.conversation_chunk_helper.add(
                        self.conversation_chunk
                    )
        else:
            # Create a new conversation
            title_info = self.embedding_search.title_info
            self.conversation_info = ConversationModel(
                user_id=self.user_id,
                module="chatcomplete",
                page_id=title_info["page_id"],
                rev_id=title_info["rev_id"],
            )
            self.conversation_info = await self.conversation_helper.add(
                self.conversation_info,
            )
            self.conversation_chunk = ConversationChunkModel(
                conversation_id=self.conversation_info.id,
                message_data=[],
                tokens=0,
            )
            self.conversation_chunk = await self.conversation_chunk_helper.add(
                self.conversation_chunk
            )

        # Extract document from wiki page index
        self.extract_doc = None
        if embedding_search is not None:
            self.extract_doc, token_usage = await self.embedding_search.search(
                question, **embedding_search
            )
            if self.extract_doc is not None:
                self.question_tokens += token_usage

        return ChatCompleteServicePrepareResponse(
            extract_doc=self.extract_doc,
            question_tokens=self.question_tokens,
            conversation_id=self.conversation_info.id,
            chunk_id=self.conversation_chunk.id,
        )

    async def finish_chat_complete(
        self, on_message: Optional[Callable] = None
    ) -> ChatCompleteServiceResponse:
        """Run the OpenAI chat completion and persist the exchange.

        Must be called after :meth:`prepare_chat_complete`.

        :param on_message: optional streaming callback; when given, the
            streaming API variant is used.
        :returns: the assistant reply plus token usage and message ids.
        """
        delta_data = {}

        # Replay prior conversation history as the message log.
        message_log = []
        if self.conversation_chunk is not None:
            for message in self.conversation_chunk.message_data:
                message_log.append(
                    {
                        "role": message["role"],
                        "content": message["content"],
                    }
                )

        if self.extract_doc is not None:
            # Prefer the markdown rendering of each extracted doc; fall back
            # to plain text.
            doc_prompt_content = "\n".join(
                [
                    "%d. %s" % (i + 1, doc["markdown"] or doc["text"])
                    for i, doc in enumerate(self.extract_doc)
                ]
            )
            doc_prompt = utils.config.get_prompt(
                "extracted_doc", "prompt", {"content": doc_prompt_content}
            )
            message_log.append({"role": "user", "content": doc_prompt})

        system_prompt = utils.config.get_prompt("chat", "system_prompt")

        # Start chat complete
        if on_message is not None:
            response = await self.openai_api.chat_complete_stream(
                self.question, system_prompt, message_log, on_message
            )
        else:
            response = await self.openai_api.chat_complete(
                self.question, system_prompt, message_log
            )

        # Short description for the conversation list UI.
        description = response["message"][0:150]

        question_msg_id = utils.web.generate_uuid()
        response_msg_id = utils.web.generate_uuid()
        new_message_data = [
            {
                "id": question_msg_id,
                "role": "user",
                "content": self.question,
                "tokens": self.question_tokens,
                "time": self.conversation_start_time,
            },
            {
                "id": response_msg_id,
                "role": "assistant",
                "content": response["message"],
                "tokens": response["message_tokens"],
                "time": int(time.time()),
            },
        ]

        if self.conversation_info is not None:
            total_token_usage = self.question_tokens + response["message_tokens"]

            # Generate title if not exists
            if self.conversation_info.title is None:
                title = None
                try:
                    title, token_usage = await self.make_title(new_message_data)
                    delta_data["title"] = title
                except Exception as e:
                    # Title generation is best-effort; log and continue.
                    print(str(e), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)

                self.conversation_info.title = title

            # Update conversation info
            self.conversation_info.description = description
            await self.conversation_helper.update(self.conversation_info)

            # Update conversation chunk
            self.conversation_chunk.message_data.extend(new_message_data)
            flag_modified(self.conversation_chunk, "message_data")
            self.conversation_chunk.tokens += total_token_usage

            await self.conversation_chunk_helper.update(self.conversation_chunk)

        return ChatCompleteServiceResponse(
            message=response["message"],
            message_tokens=response["message_tokens"],
            total_tokens=response["total_tokens"],
            finish_reason=response["finish_reason"],
            question_message_id=question_msg_id,
            response_message_id=response_msg_id,
            delta_data=delta_data,
        )

    async def set_latest_point_cost(self, point_cost: int) -> bool:
        """Attach *point_cost* to the most recent assistant message.

        :returns: True when a message was updated, False otherwise.
        """
        if self.conversation_chunk is None:
            return False

        if len(self.conversation_chunk.message_data) == 0:
            return False

        # Walk backwards to find the latest assistant message.
        for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1):
            if self.conversation_chunk.message_data[i]["role"] == "assistant":
                self.conversation_chunk.message_data[i]["point_cost"] = point_cost
                flag_modified(self.conversation_chunk, "message_data")
                await self.conversation_chunk_helper.update(self.conversation_chunk)
                return True

        # BUGFIX: previously fell off the end returning None instead of bool.
        return False

    async def make_summary(self, message_log_list: list) -> tuple[str, int]:
        """Summarize a message log; returns (summary_text, token_count)."""
        chat_log: list[str] = []

        for message_data in message_log_list:
            if message_data["role"] == "summary":
                chat_log.append(message_data["content"])
            elif message_data["role"] == "assistant":
                chat_log.append(
                    f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
                )
            else:
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

        summary_system_prompt = utils.config.get_prompt("summary", "system_prompt")
        summary_prompt = utils.config.get_prompt(
            "summary", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            summary_prompt, summary_system_prompt
        )

        return response["message"], response["message_tokens"]

    async def make_title(self, message_log_list: list) -> tuple[str, int]:
        """Generate a conversation title; returns (title, token_count)."""
        chat_log: list[str] = []

        for message_data in message_log_list:
            if message_data["role"] == "assistant":
                chat_log.append(
                    f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
                )
            elif message_data["role"] == "user":
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

        title_system_prompt = utils.config.get_prompt("title", "system_prompt")
        title_prompt = utils.config.get_prompt(
            "title", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            title_prompt, title_system_prompt
        )

        return response["message"], response["message_tokens"]