from __future__ import annotations import traceback from typing import Optional, Tuple, TypedDict from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationChunkModel import sys from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel import config import utils.config from aiohttp import web from api.model.embedding_search.title_collection import TitleCollectionModel from sqlalchemy.orm.attributes import flag_modified from service.database import DatabaseService from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService from service.mediawiki_api import MediaWikiApi from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService class ChatCompleteServicePrepareResponse(TypedDict): extract_doc: list question_tokens: int class ChatCompleteServiceResponse(TypedDict): message: str message_tokens: int total_tokens: int finish_reason: str conversation_id: int delta_data: dict class ChatCompleteService: def __init__(self, dbs: DatabaseService, title: str): self.dbs = dbs self.title = title self.base_title = title.split("/")[0] self.embedding_search = EmbeddingSearchService(dbs, title) self.conversation_helper = ConversationHelper(dbs) self.conversation_chunk_helper = ConversationChunkHelper(dbs) self.conversation_info: Optional[ConversationModel] = None self.conversation_chunk: Optional[ConversationChunkModel] = None self.tiktoken: TikTokenService = None self.extract_doc: list = None self.mwapi = MediaWikiApi.create() self.openai_api = OpenAIApi.create() self.user_id = 0 self.question = "" self.question_tokens: Optional[int] = None self.conversation_id: Optional[int] = None self.delta_data = {} async def __aenter__(self): self.tiktoken = await TikTokenService.create() await self.embedding_search.__aenter__() await self.conversation_helper.__aenter__() await self.conversation_chunk_helper.__aenter__() return self async def __aexit__(self, exc_type, exc, tb): await self.embedding_search.__aexit__(exc_type, exc, tb) await self.conversation_helper.__aexit__(exc_type, exc, tb) await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb) async def page_index_exists(self): return await self.embedding_search.page_index_exists(False) async def get_question_tokens(self, question: str): return await self.tiktoken.get_tokens(question) async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None, embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse: if user_id is not None: user_id = int(user_id) self.user_id = user_id self.question = question self.conversation_info = None if conversation_id is not None: self.conversation_id = int(conversation_id) self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id) else: self.conversation_id = None if self.conversation_info is not None: if self.conversation_info.user_id != user_id: raise web.HTTPUnauthorized() if question_tokens is None: self.question_tokens = await self.get_question_tokens(question) else: self.question_tokens = question_tokens if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS): # If the question is too long, we need to truncate it raise web.HTTPRequestEntityTooLarge() # Extract document from wiki page index self.extract_doc = None if embedding_search is not None: self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search) if self.extract_doc is not None: self.question_tokens += token_usage return ChatCompleteServicePrepareResponse( extract_doc=self.extract_doc, question_tokens=self.question_tokens ) async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse: delta_data = {} self.conversation_chunk = None message_log = [] if self.conversation_info is not None: self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id) # If the conversation is too long, we need to make a summary if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS: summary, tokens = await self.make_summary(self.conversation_chunk.message_data) new_message_log = [ {"role": "summary", "content": summary, "tokens": tokens} ] self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens) self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id message_log = [] for message in self.conversation_chunk.message_data: message_log.append({ "role": message["role"], "content": message["content"], }) if self.extract_doc is not None: doc_prompt_content = "\n".join(["%d. %s" % ( i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)]) doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", { "content": doc_prompt_content}) message_log.append({"role": "user", "content": doc_prompt}) system_prompt = utils.config.get_prompt("chat", "system_prompt") # Start chat complete if on_message is not None: response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message) else: response = await self.openai_api.chat_complete(self.question, system_prompt, message_log) if self.conversation_info is None: # Create a new conversation message_log_list = [ {"role": "user", "content": self.question, "tokens": self.question_tokens}, {"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]}, ] title = None try: title, token_usage = await self.make_title(message_log_list) delta_data["title"] = title except Exception as e: title = config.CHATCOMPLETE_DEFAULT_CONVERSATION_TITLE print(str(e), file=sys.stderr) traceback.print_exc(file=sys.stderr) total_token_usage = self.question_tokens + response["message_tokens"] title_info = self.embedding_search.title_info self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title) self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage) else: # Update the conversation chunk await self.conversation_helper.refresh_updated_at(self.conversation_id) self.conversation_chunk.message_data.append( {"role": "user", "content": self.question, "tokens": self.question_tokens}) self.conversation_chunk.message_data.append( {"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]}) flag_modified(self.conversation_chunk, "message_data") self.conversation_chunk.tokens += self.question_tokens + \ response["message_tokens"] await self.conversation_chunk_helper.update(self.conversation_chunk) return ChatCompleteServiceResponse( message=response["message"], message_tokens=response["message_tokens"], total_tokens=response["total_tokens"], finish_reason=response["finish_reason"], conversation_id=self.conversation_info.id, delta_data=delta_data ) async def set_latest_point_cost(self, point_cost: int) -> bool: if self.conversation_chunk is None: return False if len(self.conversation_chunk.message_data) == 0: return False for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): if self.conversation_chunk.message_data[i]["role"] == "assistant": self.conversation_chunk.message_data[i]["point_cost"] = point_cost flag_modified(self.conversation_chunk, "message_data") await self.conversation_chunk_helper.update(self.conversation_chunk) return True async def make_summary(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] for message_data in message_log_list: if message_data["role"] == 'summary': chat_log.append(message_data["content"]) elif message_data["role"] == 'assistant': chat_log.append( f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') else: chat_log.append(f'User: {message_data["content"]}') chat_log_str = '\n'.join(chat_log) summary_system_prompt = utils.config.get_prompt( "summary", "system_prompt") summary_prompt = utils.config.get_prompt( "summary", "prompt", {"content": chat_log_str}) response = await self.openai_api.chat_complete(summary_prompt, summary_system_prompt) return response["message"], response["message_tokens"] async def make_title(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] for message_data in message_log_list: if message_data["role"] == 'assistant': chat_log.append( f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}') elif message_data["role"] == 'user': chat_log.append(f'User: {message_data["content"]}') chat_log_str = '\n'.join(chat_log) title_system_prompt = utils.config.get_prompt("title", "system_prompt") title_prompt = utils.config.get_prompt( "title", "prompt", {"content": chat_log_str}) response = await self.openai_api.chat_complete(title_prompt, title_system_prompt) return response["message"], response["message_tokens"]