尝试修复资源未关闭的问题

master
落雨楓 2 years ago
parent 99e23f5281
commit 7b4c70147b

@ -50,26 +50,29 @@ class ChatCompleteTask:
self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title) self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
self.chat_complete = await self.chat_complete_service.__aenter__() self.chat_complete = await self.chat_complete_service.__aenter__()
if await self.chat_complete.page_index_exists(): try:
question_tokens = await self.tiktoken.get_tokens(question) if await self.chat_complete.page_index_exists():
question_tokens = await self.tiktoken.get_tokens(question)
extract_limit = embedding_search["limit"] or 10 extract_limit = embedding_search["limit"] or 10
self.transatcion_id: Optional[str] = None self.transatcion_id: Optional[str] = None
self.point_cost: int = 0 self.point_cost: int = 0
if not self.is_system: if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete", usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
question_tokens, extract_limit) question_tokens, extract_limit)
self.transatcion_id = usage_res["transaction_id"] self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"] self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id, res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search) user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
return res return res
else: else:
await self._exit() raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title) except Exception as e:
await self.end()
raise e
async def _on_message(self, delta_message: str): async def _on_message(self, delta_message: str):
self.chunks.append(delta_message) self.chunks.append(delta_message)
@ -122,9 +125,9 @@ class ChatCompleteTask:
await self._on_error(e) await self._on_error(e)
finally: finally:
await self._exit() await self.end()
async def _exit(self): async def end(self):
await self.chat_complete_service.__aexit__(None, None, None) await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id] del chat_complete_tasks[self.task_id]
self.is_finished = True self.is_finished = True

@ -24,7 +24,6 @@ class BaseHelper:
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.session.__aexit__(exc_type, exc, tb) await self.session.__aexit__(exc_type, exc, tb)
pass
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)

@ -240,12 +240,13 @@ class ChatCompleteService:
message_log = [] message_log = []
if self.conversation_chunk is not None: if self.conversation_chunk is not None:
for message in self.conversation_chunk.message_data: for message in self.conversation_chunk.message_data:
message_log.append( if message["role"] in ["user", "assistant"]:
{ message_log.append(
"role": message["role"], {
"content": message["content"], "role": message["role"],
} "content": message["content"],
) }
)
if self.extract_doc is not None: if self.extract_doc is not None:
doc_prompt_content = "\n".join( doc_prompt_content = "\n".join(

@ -39,8 +39,7 @@ class DatabaseService:
async def init(self): async def init(self):
loop = local.loop loop = local.loop
self.pool = asyncpg.create_pool(**config.DATABASE, loop=loop) self.pool = await asyncpg.create_pool(**config.DATABASE, loop=loop)
await self.pool.__aenter__()
engine = create_async_engine(get_dsn(), echo=config.DEBUG) engine = create_async_engine(get_dsn(), echo=config.DEBUG)
self.engine = engine self.engine = engine

Loading…
Cancel
Save