@@ -19,6 +19,11 @@ from service.openai_api import OpenAIApi
 from service.tiktoken import TikTokenService


+class ChatCompleteServicePrepareResponse(TypedDict):
+    extract_doc: list
+    question_tokens: int
+
+
 class ChatCompleteServiceResponse(TypedDict):
     message: str
     message_tokens: int
@@ -44,9 +49,18 @@ class ChatCompleteService:
+        self.tiktoken: TikTokenService = None
+        self.extract_doc: list = None
+
         self.mwapi = MediaWikiApi.create()
         self.openai_api = OpenAIApi.create()

+        self.user_id = 0
+        self.question = ""
+        self.question_tokens: Optional[int] = None
+        self.conversation_id: Optional[int] = None
+        self.delta_data = {}
+
     async def __aenter__(self):
         self.tiktoken = await TikTokenService.create()
@@ -67,26 +81,55 @@ class ChatCompleteService:
     async def get_question_tokens(self, question: str):
         return await self.tiktoken.get_tokens(question)

-    async def chat_complete(self, question: str, on_message: Optional[callable] = None, on_extracted_doc: Optional[callable] = None,
-                            conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None,
-                            embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServiceResponse:
+    async def prepare_chat_complete(self, question: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None,
+                                    question_tokens: Optional[int] = None,
+                                    embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
         if user_id is not None:
             user_id = int(user_id)

+        self.user_id = user_id
+        self.question = question
+
         self.conversation_info = None
         if conversation_id is not None:
-            conversation_id = int(conversation_id)
-            self.conversation_info = await self.conversation_helper.get_conversation(conversation_id)
+            self.conversation_id = int(conversation_id)
+            self.conversation_info = await self.conversation_helper.find_by_id(self.conversation_id)
+        else:
+            self.conversation_id = None

+        if self.conversation_info is not None:
+            if self.conversation_info.user_id != user_id:
+                raise web.HTTPUnauthorized()
+
+        if question_tokens is None:
+            self.question_tokens = await self.get_question_tokens(question)
+        else:
+            self.question_tokens = question_tokens
+
+        if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
+                self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
+            # The question is too long, reject the request
+            raise web.HTTPRequestEntityTooLarge()
+
+        # Extract document from wiki page index
+        self.extract_doc = None
+        if embedding_search is not None:
+            self.extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
+            if self.extract_doc is not None:
+                self.question_tokens += token_usage
+
+        return ChatCompleteServicePrepareResponse(
+            extract_doc=self.extract_doc,
+            question_tokens=self.question_tokens
+        )
+
+    async def finish_chat_complete(self, on_message: Optional[callable] = None) -> ChatCompleteServiceResponse:
-        delta_data = {}
+        self.conversation_chunk = None
+
         message_log = []
         if self.conversation_info is not None:
-            if self.conversation_info.user_id != user_id:
-                raise web.HTTPUnauthorized()
-
-            self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(conversation_id)
+            self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(self.conversation_id)

             # If the conversation is too long, we need to make a summary
             if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
@@ -95,9 +138,9 @@ class ChatCompleteService:
                     {"role": "summary", "content": summary, "tokens": tokens}
                 ]

-                self.conversation_chunk = await self.conversation_chunk_helper.add(conversation_id, new_message_log, tokens)
+                self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_id, new_message_log, tokens)

-                delta_data["conversation_chunk_id"] = self.conversation_chunk.id
+                self.delta_data["conversation_chunk_id"] = self.conversation_chunk.id

             message_log = []
             for message in self.conversation_chunk.message_data:
@@ -106,40 +149,26 @@ class ChatCompleteService:
                     "content": message["content"],
                 })

-        if question_tokens is None:
-            question_tokens = await self.get_question_tokens(question)
-
-        if (len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS and
-                question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS):
-            # If the question is too long, we need to truncate it
-            raise web.HTTPRequestEntityTooLarge()
-
-        extract_doc = None
-        if embedding_search is not None:
-            extract_doc, token_usage = await self.embedding_search.search(question, **embedding_search)
-            if extract_doc is not None:
-                if on_extracted_doc is not None:
-                    await on_extracted_doc(extract_doc)
-
-                question_tokens = token_usage
-
-                doc_prompt_content = "\n".join(["%d. %s" % (
-                    i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(extract_doc)])
+        if self.extract_doc is not None:
+            doc_prompt_content = "\n".join(["%d. %s" % (
+                i + 1, doc["markdown"] or doc["text"]) for i, doc in enumerate(self.extract_doc)])

-                doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
-                    "content": doc_prompt_content})
-                message_log.append({"role": "user", "content": doc_prompt})
+            doc_prompt = utils.config.get_prompt("extracted_doc", "prompt", {
+                "content": doc_prompt_content})
+            message_log.append({"role": "user", "content": doc_prompt})

         system_prompt = utils.config.get_prompt("chat", "system_prompt")

         # Start chat complete
         if on_message is not None:
-            response = await self.openai_api.chat_complete_stream(question, system_prompt, message_log, on_message)
+            response = await self.openai_api.chat_complete_stream(self.question, system_prompt, message_log, on_message)
         else:
-            response = await self.openai_api.chat_complete(question, system_prompt, message_log)
+            response = await self.openai_api.chat_complete(self.question, system_prompt, message_log)

         if self.conversation_info is None:
             # Create a new conversation
             message_log_list = [
-                {"role": "user", "content": question, "tokens": question_tokens},
+                {"role": "user", "content": self.question, "tokens": self.question_tokens},
                 {"role": "assistant",
                  "content": response["message"], "tokens": response["message_tokens"]},
             ]
@@ -152,21 +181,21 @@ class ChatCompleteService:
                 print(str(e), file=sys.stderr)
                 traceback.print_exc(file=sys.stderr)

-            total_token_usage = question_tokens + response["message_tokens"]
+            total_token_usage = self.question_tokens + response["message_tokens"]

             title_info = self.embedding_search.title_info
-            self.conversation_info = await self.conversation_helper.add(user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
+            self.conversation_info = await self.conversation_helper.add(self.user_id, "chatcomplete", page_id=title_info["page_id"], rev_id=title_info["rev_id"], title=title)
             self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_info.id, message_log_list, total_token_usage)
         else:
             # Update the conversation chunk
-            await self.conversation_helper.refresh_updated_at(conversation_id)
+            await self.conversation_helper.refresh_updated_at(self.conversation_id)

             self.conversation_chunk.message_data.append(
-                {"role": "user", "content": question, "tokens": question_tokens})
+                {"role": "user", "content": self.question, "tokens": self.question_tokens})
             self.conversation_chunk.message_data.append(
                 {"role": "assistant", "content": response["message"], "tokens": response["message_tokens"]})
             flag_modified(self.conversation_chunk, "message_data")

-            self.conversation_chunk.tokens += question_tokens + \
+            self.conversation_chunk.tokens += self.question_tokens + \
                 response["message_tokens"]

             await self.conversation_chunk_helper.update(self.conversation_chunk)
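
Usage sketch: the diff above splits the old chat_complete() into a prepare phase (permission check, token counting, embedding search) and a finish phase (prompt assembly, OpenAI call, conversation persistence), with the intermediate state carried on the service instance. A minimal driver might look like the following; make_service and send_chunk are hypothetical stand-ins, since the service constructor and the transport layer are not part of this change:

async def handle_chat(params: dict, send_chunk) -> dict:
    # make_service() is a hypothetical factory that builds a ChatCompleteService
    # the same way the rest of the application does; this diff does not touch
    # the constructor.
    service = make_service()
    async with service:
        prepared = await service.prepare_chat_complete(
            question=params["question"],
            conversation_id=params.get("conversation_id"),
            user_id=params["user_id"],
            embedding_search={"limit": 3},  # hypothetical EmbeddingSearchArgs
        )
        # prepared["extract_doc"] and prepared["question_tokens"] are available
        # before the (slow) model call, so they can be reported to the client
        # early; finish_chat_complete() then reads the prepared state from the
        # instance and streams the answer through the on_message callback.
        return await service.finish_chat_complete(on_message=send_chunk)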