from __future__ import annotations

import sys
import traceback
from typing import Optional

from agentkit.errors import AgentKitPluginError
from agentkit.hooks import agentkit_preprocessor
from agentkit.base.preprocessor import AgentKitPreprocessor
from agentkit.types import PluginConfigParam
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import (
    MediaWikiApi,
    MediaWikiApiException,
    MediaWikiPageNotFoundException,
)


@agentkit_preprocessor(id="mediawiki_embedding_search", name="页面内容提取", description="每次对话前,在页面中提取相关内容")
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
    """Preprocessor that looks up page content related to the prompt.

    When the conversation opens, the embedding index of the configured page
    is refreshed if the search service reports that it needs updating.
    Before each completion, the prompt is searched against that index and
    the hits are returned as a Markdown link list under the ``document``
    output key.

    The preprocessor is inert (no index update, no search) when ``props``
    carries no ``page_title``.
    """

    # Parameters supplied by the plugin configuration.
    CONFIG: list[PluginConfigParam] = [
        {
            "id": "distance_limit",
            "name": "向量查找距离限制",
            "type": "float",
            "required": True,
        },
        {
            "id": "extract_limit",
            "name": "提取数量限制",
            "type": "int",
            "required": True,
            "default": 15,
        },
    ]

    # Values this preprocessor contributes to the pipeline.
    OUTPUT: list[PluginConfigParam] = [
        {
            "id": "document",
            "name": "抽取的文档",
            "type": "string",
            "default": "",
        },
    ]

    def __init__(self, props: dict, distance_limit: float, extract_limit: int, **kwargs):
        """Store the configuration and create the search service if possible.

        Args:
            props: Conversation properties. Reads ``page_title`` and
                ``db_service`` here; ``caller_type`` and ``user`` are read
                later by ``on_open_conversation``.
            distance_limit: Passed through to the embedding search call.
            extract_limit: Passed through to the embedding search call.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``props`` has no (truthy) ``db_service``.
        """
        self.props = props
        self.distance_limit = distance_limit
        self.extract_limit = extract_limit

        self.page_title = props.get("page_title", "")
        self.db_service = props.get("db_service", None)

        if not self.db_service:
            raise ValueError("db_service is required in props")

        # Without a page title there is nothing to index or search.
        if self.page_title:
            self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
        else:
            self.embedding_search = None

    async def __aenter__(self):
        """Enter the wrapped search service's async context, if there is one."""
        if self.embedding_search:
            # Must be awaited: a bare call only creates a coroutine object
            # and the service would never actually be entered.
            await self.embedding_search.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped search service's async context, if there is one.

        The service's return value is deliberately not propagated, so this
        context manager never suppresses an in-flight exception.
        """
        if self.embedding_search:
            await self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)

    async def on_open_conversation(self, on_progress=None):
        """Refresh the page's embedding index when the service asks for it.

        For callers whose ``caller_type`` is ``"user"`` the update is wrapped
        in an AI-toolbox transaction: started before the update, ended with
        the reported token usage on success, and cancelled on failure.

        Args:
            on_progress: Forwarded unchanged to ``update_page_index``.

        Raises:
            AgentKitPluginError: On a MediaWiki API error, when the index is
                already being built, or on any other unexpected failure.
        """
        if not self.embedding_search:
            return

        mwapi = MediaWikiApi.create()
        # Initialised up front so the handlers below can test it safely even
        # when no transaction was started (non-user caller, or a failure
        # before the transaction id was obtained).
        transaction_id = None
        try:
            if await self.embedding_search.should_update_page_index():
                if self.props.get("caller_type") == "user":
                    user_id = self.props.get("user")
                    usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                    transaction_id = usage_res.get("transaction_id")

                await self.embedding_search.prepare_update_index()
                token_usage = await self.embedding_search.update_page_index(on_progress)

                if transaction_id:
                    await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)
        except MediaWikiPageNotFoundException:
            # A missing page is not treated as an error here.
            pass
        except MediaWikiApiException as e:
            error_msg = f"MediaWiki API error: {e}"
            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(f"MediaWiki API error: {e.info}", e.code) from e
        except EmbeddingRunningException as e:
            error_msg = "Page index is running now"

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(error_msg, "page-index-running") from e
        except ConnectionResetError:
            pass  # Ignore websocket close error
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
            raise AgentKitPluginError(f"Error preparing page index: {e}", "page-index-error") from e

    async def on_before_completion(self, prompt: str) -> Optional[dict]:
        """Search the page index for content related to ``prompt``.

        Args:
            prompt: The text to search for.

        Returns:
            ``{"document": <markdown list>}`` with one ``- [title](url)``
            line per hit (an empty string when there are no hits), or
            ``None`` when the index is currently being built. When no search
            service is configured, ``prompt`` itself is returned.

        Raises:
            AgentKitPluginError: If the search fails unexpectedly.
        """
        if not self.embedding_search:
            # NOTE(review): this returns a str although the annotation says
            # Optional[dict]; kept as-is because the caller's expectations
            # are not visible from this file — confirm and align.
            return prompt

        try:
            extracted_docs = await self.embedding_search.search(
                prompt, self.extract_limit, True, self.distance_limit
            )
            generated_prompt = "".join(
                f"- [{doc['title']}]({doc['url']})\n" for doc in extracted_docs
            )
            return {
                "document": generated_prompt,
            }
        except EmbeddingRunningException:
            # Index is being rebuilt; contribute nothing for this completion.
            return
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            raise AgentKitPluginError(f"Error getting related content: {e}", "related-content-error") from e