|
|
|
@ -50,26 +50,29 @@ class ChatCompleteTask:
|
|
|
|
|
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
|
|
|
|
|
self.chat_complete = await self.chat_complete_service.__aenter__()
|
|
|
|
|
|
|
|
|
|
if await self.chat_complete.page_index_exists():
|
|
|
|
|
question_tokens = await self.tiktoken.get_tokens(question)
|
|
|
|
|
|
|
|
|
|
extract_limit = embedding_search["limit"] or 10
|
|
|
|
|
|
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
|
self.point_cost: int = 0
|
|
|
|
|
if not self.is_system:
|
|
|
|
|
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
|
|
|
|
|
question_tokens, extract_limit)
|
|
|
|
|
self.transatcion_id = usage_res["transaction_id"]
|
|
|
|
|
self.point_cost = usage_res["point_cost"]
|
|
|
|
|
|
|
|
|
|
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
|
|
|
|
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
else:
|
|
|
|
|
await self._exit()
|
|
|
|
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
|
|
|
|
try:
|
|
|
|
|
if await self.chat_complete.page_index_exists():
|
|
|
|
|
question_tokens = await self.tiktoken.get_tokens(question)
|
|
|
|
|
|
|
|
|
|
extract_limit = embedding_search["limit"] or 10
|
|
|
|
|
|
|
|
|
|
self.transatcion_id: Optional[str] = None
|
|
|
|
|
self.point_cost: int = 0
|
|
|
|
|
if not self.is_system:
|
|
|
|
|
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
|
|
|
|
|
question_tokens, extract_limit)
|
|
|
|
|
self.transatcion_id = usage_res["transaction_id"]
|
|
|
|
|
self.point_cost = usage_res["point_cost"]
|
|
|
|
|
|
|
|
|
|
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
|
|
|
|
|
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
else:
|
|
|
|
|
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await self.end()
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
async def _on_message(self, delta_message: str):
|
|
|
|
|
self.chunks.append(delta_message)
|
|
|
|
@ -122,9 +125,9 @@ class ChatCompleteTask:
|
|
|
|
|
|
|
|
|
|
await self._on_error(e)
|
|
|
|
|
finally:
|
|
|
|
|
await self._exit()
|
|
|
|
|
await self.end()
|
|
|
|
|
|
|
|
|
|
async def _exit(self):
|
|
|
|
|
async def end(self):
|
|
|
|
|
await self.chat_complete_service.__aexit__(None, None, None)
|
|
|
|
|
del chat_complete_tasks[self.task_id]
|
|
|
|
|
self.is_finished = True
|
|
|
|
|