From 7b4c70147be0fdf123a5fe4c295d95b7126af91d Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Wed, 21 Jun 2023 08:43:27 +0000 Subject: [PATCH] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E4=BF=AE=E5=A4=8D=E8=B5=84?= =?UTF-8?q?=E6=BA=90=E6=9C=AA=E5=85=B3=E9=97=AD=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controller/task/ChatCompleteTask.py | 47 +++++++++++++------------ api/model/base.py | 1 - service/chat_complete.py | 13 +++---- service/database.py | 3 +- 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/api/controller/task/ChatCompleteTask.py b/api/controller/task/ChatCompleteTask.py index 0952414..b1ab227 100644 --- a/api/controller/task/ChatCompleteTask.py +++ b/api/controller/task/ChatCompleteTask.py @@ -50,26 +50,29 @@ class ChatCompleteTask: self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title) self.chat_complete = await self.chat_complete_service.__aenter__() - if await self.chat_complete.page_index_exists(): - question_tokens = await self.tiktoken.get_tokens(question) - - extract_limit = embedding_search["limit"] or 10 - - self.transatcion_id: Optional[str] = None - self.point_cost: int = 0 - if not self.is_system: - usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete", - question_tokens, extract_limit) - self.transatcion_id = usage_res["transaction_id"] - self.point_cost = usage_res["point_cost"] - - res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, - user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search) - - return res - else: - await self._exit() - raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) + try: + if await self.chat_complete.page_index_exists(): + question_tokens = await self.tiktoken.get_tokens(question) + + extract_limit = embedding_search["limit"] or 10 + + self.transatcion_id: Optional[str] = None + self.point_cost: int = 0 + if not self.is_system: + usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete", + question_tokens, extract_limit) + self.transatcion_id = usage_res["transaction_id"] + self.point_cost = usage_res["point_cost"] + + res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, + user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search) + + return res + else: + raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) + except Exception as e: + await self.end() + raise e async def _on_message(self, delta_message: str): self.chunks.append(delta_message) @@ -122,9 +125,9 @@ class ChatCompleteTask: await self._on_error(e) finally: - await self._exit() + await self.end() - async def _exit(self): + async def end(self): await self.chat_complete_service.__aexit__(None, None, None) del chat_complete_tasks[self.task_id] self.is_finished = True diff --git a/api/model/base.py b/api/model/base.py index e3fccac..f4fcd96 100644 --- a/api/model/base.py +++ b/api/model/base.py @@ -24,7 +24,6 @@ class BaseHelper: async def __aexit__(self, exc_type, exc, tb): await self.session.__aexit__(exc_type, exc, tb) - pass T = TypeVar("T", bound=BaseModel) diff --git a/service/chat_complete.py b/service/chat_complete.py index 541e19e..3dc12e0 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -240,12 +240,13 @@ class ChatCompleteService: message_log = [] if self.conversation_chunk is not None: for message in self.conversation_chunk.message_data: - message_log.append( - { - "role": message["role"], - "content": message["content"], - } - ) + if message["role"] in ["user", "assistant"]: + message_log.append( + { + "role": message["role"], + "content": message["content"], + } + ) if self.extract_doc is not None: doc_prompt_content = "\n".join( diff --git a/service/database.py b/service/database.py index 55b914f..69c623d 100644 --- a/service/database.py +++ b/service/database.py @@ -39,8 +39,7 @@ class DatabaseService: async def init(self): loop = local.loop - self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop) - await self.pool.__aenter__() + self.pool = await asyncpg.create_pool(**config.DATABASE, loop=loop) engine = create_async_engine(get_dsn(), echo=config.DEBUG) self.engine = engine