You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
5.0 KiB
Python
139 lines
5.0 KiB
Python
2 months ago
|
from __future__ import annotations
|
||
|
import sys
|
||
|
import traceback
|
||
|
from typing import Optional
|
||
|
from agentkit.errors import AgentKitPluginError
|
||
|
from agentkit.hooks import agentkit_preprocessor
|
||
|
from agentkit.base.preprocessor import AgentKitPreprocessor
|
||
|
from agentkit.types import PluginConfigParam
|
||
|
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
|
||
|
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
||
|
|
||
|
@agentkit_preprocessor(id="mediawiki_embedding_search",
                       name="页面内容提取",
                       description="每次对话前,在页面中提取相关内容")
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
    """Preprocessor that looks up page content related to each chat turn.

    Before a conversation it (re)builds the embedding index of the current
    MediaWiki page when ``EmbeddingSearchService`` reports that an update is
    needed. Before every completion it searches that index with the prompt
    and exposes the matches as the ``document`` output.
    """

    CONFIG: list[PluginConfigParam] = [
        {
            "id": "distance_limit",
            "name": "向量查找距离限制",  # "vector search distance limit"
            "type": "float",
            "required": True,
        },
        {
            "id": "extract_limit",
            "name": "提取数量限制",  # "limit on the number of extracted items"
            "type": "int",
            "required": True,
            "default": 15,
        },
    ]

    OUTPUT: list[PluginConfigParam] = [
        {
            "id": "document",
            "name": "抽取的文档",  # "extracted documents"
            "type": "string",
            "default": "",
        },
    ]

    def __init__(self, props: dict, distance_limit: float,
                 extract_limit: int, **kwargs):
        """Create the preprocessor for one conversation.

        Args:
            props: conversation properties. Keys read by this class:
                ``page_title``, ``db_service`` (required), ``caller_type``
                and ``user``.
            distance_limit: distance limit passed to the embedding search.
            extract_limit: maximum number of results requested from the
                embedding search.
            **kwargs: accepted for forward compatibility and ignored.

        Raises:
            ValueError: if ``props`` has no truthy ``db_service``.
        """
        self.props = props

        self.distance_limit = distance_limit
        self.extract_limit = extract_limit

        self.page_title = props.get("page_title", "")
        self.db_service = props.get("db_service")

        if not self.db_service:
            raise ValueError("db_service is required in props")

        # Without a page title there is nothing to index or search, and the
        # hooks below return early.
        if self.page_title:
            self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
        else:
            self.embedding_search = None

    async def __aenter__(self):
        """Enter the wrapped search service's async context, if there is one."""
        if self.embedding_search:
            # __aenter__ returns an awaitable per the async context-manager
            # protocol; without ``await`` the service's setup never runs.
            await self.embedding_search.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped search service's async context, if there is one.

        Returns None (falsy), so exceptions raised inside the ``async with``
        body are never suppressed.
        """
        if self.embedding_search:
            await self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)

    @staticmethod
    async def _cancel_transaction(mwapi, transaction_id, error_msg: str) -> None:
        """Cancel the AI-toolbox transaction ``transaction_id`` if one was started."""
        if transaction_id:
            await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

    async def on_open_conversation(self, on_progress=None):
        """Rebuild the page's embedding index before the conversation, if needed.

        For user callers the rebuild is wrapped in a MediaWiki AI-toolbox
        transaction. The transaction is ended with the token usage reported by
        ``update_page_index`` on success, and cancelled when an API error, a
        running index job or an unexpected error aborts the rebuild.

        Args:
            on_progress: optional callback forwarded to
                ``EmbeddingSearchService.update_page_index``.

        Raises:
            AgentKitPluginError: code taken from the MediaWiki API exception on
                API errors; "page-index-running" when indexing is already in
                progress; "page-index-error" for any other failure.
        """
        if not self.embedding_search:
            return

        mwapi = MediaWikiApi.create()
        # Bound before ``try`` because only user callers open a transaction,
        # and the handlers below read this even when the failure happens
        # before any transaction was started.
        transaction_id = None
        try:
            if await self.embedding_search.should_update_page_index():
                if self.props.get("caller_type") == "user":
                    user_id = self.props.get("user")
                    usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                    transaction_id = usage_res.get("transaction_id")

                await self.embedding_search.prepare_update_index()

                token_usage = await self.embedding_search.update_page_index(on_progress)

                if transaction_id:
                    await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)
        except MediaWikiPageNotFoundException:
            # Presumably the page does not exist, so there is nothing to index.
            # NOTE(review): a transaction started before this point is left open.
            pass
        except MediaWikiApiException as e:
            error_msg = f"MediaWiki API error: {e}"
            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            await self._cancel_transaction(mwapi, transaction_id, error_msg)

            raise AgentKitPluginError(f"MediaWiki API error: {e.info}", e.code) from e
        except EmbeddingRunningException as e:
            error_msg = "Page index is running now"

            await self._cancel_transaction(mwapi, transaction_id, error_msg)

            raise AgentKitPluginError(error_msg, "page-index-running") from e
        except ConnectionResetError:
            # Ignore websocket close error.
            # NOTE(review): a transaction started before this point is left open.
            pass
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            await self._cancel_transaction(mwapi, transaction_id, error_msg)

            raise AgentKitPluginError(f"Error preparing page index: {e}", "page-index-error") from e

    async def on_before_completion(self, prompt: str) -> Optional[dict]:
        """Search the page index for content related to ``prompt``.

        Args:
            prompt: the user's input, used as the search query.

        Returns:
            ``{"document": text}``, where ``text`` has one Markdown line
            ``- [title](url)`` per search result. Returns None when there is
            no page to search or the page index is currently being built.

        Raises:
            AgentKitPluginError: "related-content-error" on any other failure.
        """
        if not self.embedding_search:
            # Nothing to search. Return None, as the "index running" path
            # below does, rather than the raw prompt string, which does not
            # match the declared Optional[dict] result.
            return None

        try:
            extracted_docs = await self.embedding_search.search(
                prompt, self.extract_limit, True, self.distance_limit)
            generated_prompt = "".join(
                f"- [{doc['title']}]({doc['url']})\n" for doc in extracted_docs)
            return {"document": generated_prompt}
        except EmbeddingRunningException:
            return None
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            raise AgentKitPluginError(f"Error getting related content: {e}", "related-content-error") from e
|