You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
5.0 KiB
Python

from __future__ import annotations
import sys
import traceback
from typing import Optional
from agentkit.errors import AgentKitPluginError
from agentkit.hooks import agentkit_preprocessor
from agentkit.base.preprocessor import AgentKitPreprocessor
from agentkit.types import PluginConfigParam
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
@agentkit_preprocessor(id="mediawiki_embedding_search",
                       name="页面内容提取",
                       description="每次对话前,在页面中提取相关内容")
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
    """Preprocessor that extracts related page content through embedding search.

    When ``props`` carries a ``page_title``, an ``EmbeddingSearchService`` is
    created for that page. Without a page title the preprocessor is inert:
    the hooks return early and no service is touched.
    """

    # Parameters supplied through the plugin configuration.
    CONFIG: list[PluginConfigParam] = [
        {
            "id": "distance_limit",
            "name": "向量查找距离限制",
            "type": "float",
            "required": True,
        },
        {
            "id": "extract_limit",
            "name": "提取数量限制",
            "type": "int",
            "required": True,
            "default": 15,
        },
    ]

    # Values this preprocessor produces.
    OUTPUT: list[PluginConfigParam] = [
        {
            "id": "document",
            "name": "抽取的文档",
            "type": "string",
            "default": "",
        },
    ]

    def __init__(self, props: dict, distance_limit: float,
                 extract_limit: int, **kwargs):
        """Store the configuration and build the search service if a page is given.

        Args:
            props: Conversation properties. ``db_service`` is mandatory;
                ``page_title``, ``caller_type`` and ``user`` are read when present.
            distance_limit: Distance limit forwarded to the embedding search.
            extract_limit: Result-count limit forwarded to the embedding search.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``props`` has no truthy ``db_service``.
        """
        self.props = props
        self.distance_limit = distance_limit
        self.extract_limit = extract_limit

        self.page_title = props.get("page_title", "")
        self.db_service = props.get("db_service", None)
        if not self.db_service:
            raise ValueError("db_service is required in props")

        if self.page_title:
            self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
        else:
            self.embedding_search = None

    async def __aenter__(self):
        """Enter the wrapped search service's async context, if there is one."""
        if self.embedding_search:
            # __aenter__ returns an awaitable; without `await` it never runs.
            await self.embedding_search.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped search service's async context, if there is one."""
        if self.embedding_search:
            await self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)

    async def on_open_conversation(self, on_progress=None):
        """Refresh the page index when the search service says it is needed.

        For callers whose ``caller_type`` is ``"user"`` a usage transaction is
        started before indexing, ended with the token usage afterwards, and
        cancelled if indexing fails.

        Args:
            on_progress: Passed through to ``update_page_index``.

        Raises:
            AgentKitPluginError: On MediaWiki API errors, when an index run is
                already in progress, or on any other unexpected failure.
        """
        if not self.embedding_search:
            return

        mwapi = MediaWikiApi.create()
        # Must exist before the try block: the except handlers read it even
        # when no transaction was ever started (non-user caller, early failure).
        transaction_id = None
        try:
            if await self.embedding_search.should_update_page_index():
                if self.props.get("caller_type") == "user":
                    user_id = self.props.get("user")
                    usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                    transaction_id = usage_res.get("transaction_id")

                await self.embedding_search.prepare_update_index()
                token_usage = await self.embedding_search.update_page_index(on_progress)
                if transaction_id:
                    await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)
        except MediaWikiPageNotFoundException:
            # NOTE(review): a transaction started above is neither ended nor
            # cancelled on this path - confirm that is intended.
            pass
        except MediaWikiApiException as e:
            error_msg = f"MediaWiki API error: {e}"
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(f"MediaWiki API error: {e.info}", e.code) from e
        except EmbeddingRunningException as e:
            error_msg = "Page index is running now"
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(error_msg, "page-index-running") from e
        except ConnectionResetError:
            pass  # Ignore websocket close error
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(f"Error preparing page index: {e}", "page-index-error") from e

    async def on_before_completion(self, prompt: str) -> Optional[dict]:
        """Search for content related to ``prompt`` and format it as a link list.

        Args:
            prompt: Text used as the search query.

        Returns:
            ``{"document": text}`` where ``text`` holds one ``- [title](url)``
            line per search result; ``None`` when an index run is in progress.
            When no search service exists the prompt is returned unchanged.

        Raises:
            AgentKitPluginError: If the search fails for any other reason.
        """
        if not self.embedding_search:
            # NOTE(review): returns a str despite the Optional[dict] annotation;
            # kept as-is because callers may rely on it - confirm.
            return prompt

        try:
            extracted_docs = await self.embedding_search.search(
                prompt, self.extract_limit, True, self.distance_limit)
            generated_prompt = "".join(
                f"- [{doc['title']}]({doc['url']})\n" for doc in extracted_docs
            )
            return {
                "document": generated_prompt
            }
        except EmbeddingRunningException:
            return
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            raise AgentKitPluginError(f"Error getting related content: {e}", "related-content-error") from e