From e5a643afd2d91aed0f2dadd6c8419311b1083eb4 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Fri, 29 Nov 2024 14:07:55 +0000 Subject: [PATCH] =?UTF-8?q?=E7=BC=96=E5=86=99AgentKit=E6=A1=86=E6=9E=B6?= =?UTF-8?q?=EF=BC=8C=E5=B0=86AI=E5=AF=B9=E8=AF=9D=E5=88=86=E7=A6=BB?= =?UTF-8?q?=E5=88=B0AgentKit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agentkit/agentkit.py | 6 + agentkit/base/llm_function.py | 23 +++ agentkit/base/modelapi.py | 26 ++++ agentkit/base/postprocessor.py | 35 +++++ agentkit/base/preprocessor.py | 26 ++++ agentkit/context.py | 86 +++++++++++ agentkit/errors.py | 14 ++ agentkit/hooks.py | 62 ++++++++ .../preprocessor/mw_embedding_search.py | 139 ++++++++++++++++++ agentkit/types.py | 52 +++++++ server/model/embedding_search/page_index.py | 5 +- service/chat_complete.py | 4 +- 12 files changed, 473 insertions(+), 5 deletions(-) create mode 100644 agentkit/agentkit.py create mode 100644 agentkit/base/llm_function.py create mode 100644 agentkit/base/modelapi.py create mode 100644 agentkit/base/postprocessor.py create mode 100644 agentkit/base/preprocessor.py create mode 100644 agentkit/context.py create mode 100644 agentkit/errors.py create mode 100644 agentkit/hooks.py create mode 100644 agentkit/plugin/preprocessor/mw_embedding_search.py create mode 100644 agentkit/types.py diff --git a/agentkit/agentkit.py b/agentkit/agentkit.py new file mode 100644 index 0000000..d8b2f50 --- /dev/null +++ b/agentkit/agentkit.py @@ -0,0 +1,6 @@ +from __future__ import annotations + + +class AgentKit: + def get_agent(self, agent_id: str): + pass \ No newline at end of file diff --git a/agentkit/base/llm_function.py b/agentkit/base/llm_function.py new file mode 100644 index 0000000..44a5aa3 --- /dev/null +++ b/agentkit/base/llm_function.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import Any, Optional +from agentkit.context import ConversationContext +from agentkit.types import LLMFunctionInfo, PluginConfigParam + + +class AgentKitLLMFunction: + CONFIG: list[PluginConfigParam] = [] + + def __init__(self, props: dict, **kwargs): + pass + + + async def get_llm_function_info() -> Optional[LLMFunctionInfo]: + """ + Get function information for the LLM function. + return None to disable the function for this context. + """ + return None + + + async def run(self, context: ConversationContext, params: Any) -> str: + return "Function not implemented" \ No newline at end of file diff --git a/agentkit/base/modelapi.py b/agentkit/base/modelapi.py new file mode 100644 index 0000000..e95acb7 --- /dev/null +++ b/agentkit/base/modelapi.py @@ -0,0 +1,26 @@ +from __future__ import annotations +from typing import Any + +from agentkit.context import ConversationContext +from agentkit.types import PluginConfigParam + +class AgentKitModelApi: + CONFIG: list[PluginConfigParam] = [] + + def __init__(self, props: dict, **kwargs): + pass + + async def __aenter__(self) -> AgentKitModelApi: + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + return None + + async def chatcomplete(self, context: ConversationContext) -> ConversationContext: + return context + + async def chatcomplete_stream(self, context: ConversationContext) -> ConversationContext: + return context + + async def get_token_count(self, context: ConversationContext) -> int: + return 0 \ No newline at end of file diff --git a/agentkit/base/postprocessor.py b/agentkit/base/postprocessor.py new file mode 100644 index 0000000..3faf75f --- /dev/null +++ b/agentkit/base/postprocessor.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from typing import Optional +from agentkit.context import ConversationContext +from agentkit.types import PluginConfigParam, OnProgressCallback + +class AgentKitPostprocessor: + CONFIG: list[PluginConfigParam] = [] + OUTPUT: list[PluginConfigParam] = [] + + def __init__(self, props: dict, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None: + return None + + async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None: + return None + + async def on_after_completion(self, context: ConversationContext) -> Optional[dict]: + """ + Called after the conversation, before the completion is returned to the user. + """ + return None + + async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]: + """ + Called after the conversation, after the completion is returned to the user. + """ + return None \ No newline at end of file diff --git a/agentkit/base/preprocessor.py b/agentkit/base/preprocessor.py new file mode 100644 index 0000000..571761c --- /dev/null +++ b/agentkit/base/preprocessor.py @@ -0,0 +1,26 @@ +from __future__ import annotations +from typing import Optional +from agentkit.context import ConversationContext +from agentkit.types import PluginConfigParam, OnProgressCallback + +class AgentKitPreprocessor: + CONFIG: list[PluginConfigParam] = [] + OUTPUT: list[PluginConfigParam] = [] + + def __init__(self, props: dict, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None: + return None + + async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None: + return None + + async def on_before_completion(self, context: ConversationContext) -> Optional[dict]: + return None \ No newline at end of file diff --git a/agentkit/context.py b/agentkit/context.py new file mode 100644 index 0000000..b13d97f --- /dev/null +++ b/agentkit/context.py @@ -0,0 +1,86 @@ +from __future__ import annotations +from typing import Optional +from agentkit.types import BaseMessage + +class ConversationData: + KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"] + + def __init__(self): + # Persistent properties + self.history: list[BaseMessage] = [] + """Chat history""" + + self.store: dict = {} + """Store for plugins""" + + self.user_id: Optional[str] = None + """Current user ID""" + + self.conversation_id: Optional[str] = None + """Current conversation ID""" + + self.conversation_rounds: int = 0 + """Number of conversation rounds""" + + def to_dict(self): + return {key: getattr(self, key) for key in self.KEYS_IN_STORE} + + +class ConversationContext(ConversationData): + def __init__(self): + super().__init__() + + # Temporary properties + self.prompt = "" + """User prompt""" + + self.system_prompt = "" + """System prompt""" + + self.session_data: dict = {} + """Temporary data for plugins""" + + self.completion: Optional[str] = None + """Completion result""" + + + def get_messages_by_role(self, role: str) -> list[BaseMessage]: + return [msg for msg in self.history if msg["role"] == role] + + + def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext: + sub_context = ConversationContext() + sub_context.user_id = self.user_id + + if copy_store: + sub_context.store = self.store.copy() + + if copy_session_data: + sub_context.session_data = self.session_data.copy() + + return sub_context + + + def __getitem__(self, key): + return self.session_data[key] + + + def __setitem__(self, key, value): + self.session_data[key] = value + return value + + + def __delitem__(self, key): + del self.session_data[key] + + + def __contains__(self, key): + return key in self.session_data + + + def __iter__(self): + return iter(self.session_data) + + + def __len__(self): + return len(self.session_data) \ No newline at end of file diff --git a/agentkit/errors.py b/agentkit/errors.py new file mode 100644 index 0000000..e1b8c6a --- /dev/null +++ b/agentkit/errors.py @@ -0,0 +1,14 @@ +from __future__ import annotations + + +class AgentKitPluginError(Exception): + """Base class for exceptions in the agent plugin module.""" + def __init__(self, message, code=None): + self.message = message + self.code = code + + def __str__(self): + return f"[{self.code}] ${self.message}" + + def __repr__(self): + return f"AgentPluginError({self.message}, {self.code})" \ No newline at end of file diff --git a/agentkit/hooks.py b/agentkit/hooks.py new file mode 100644 index 0000000..90e1625 --- /dev/null +++ b/agentkit/hooks.py @@ -0,0 +1,62 @@ +from __future__ import annotations +import sys +import traceback +from typing import List, Dict, Callable +from agentkit.types import PluginConfigParam, HookInfo + + +class AgentKitHooks: + HookRegistry: Dict[str, List[HookInfo]] = {} + + @classmethod + def register(cls, hook_name: str, class_: Callable, + id: str, name: str = "", description: str = "", + param_def: list[PluginConfigParam] = []): + if hook_name not in cls.HookRegistry: + cls.HookRegistry[hook_name] = [] + + if name == "": + name = id + + cls.HookRegistry[hook_name].append(HookInfo( + id=id, + name=name, + description=description, + params_def=param_def, + class_=class_ + )) + + @classmethod + def get(cls, hook_name: str) -> List[HookInfo]: + return cls.HookRegistry.get(hook_name, []) + + @classmethod + def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]: + needle_hook_ids = list(hook_item_params.keys()) + instances = [] + for hook_info in cls.get(hook_name): + if hook_info["id"] in needle_hook_ids: + params = hook_item_params[hook_info["id"]] + try: + instance = hook_info["class_"](**params) + instances.append(instance) + except Exception as e: + print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr) + traceback.print_exc() + + return instances + +def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "", + param_def: list[PluginConfigParam] = []): + AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def) + return orig_class + +def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "", + param_def: list[PluginConfigParam] = []): + AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def) + return orig_class + +def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "", + param_def: list[PluginConfigParam] = []): + AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def) + return orig_class \ No newline at end of file diff --git a/agentkit/plugin/preprocessor/mw_embedding_search.py b/agentkit/plugin/preprocessor/mw_embedding_search.py new file mode 100644 index 0000000..5287c73 --- /dev/null +++ b/agentkit/plugin/preprocessor/mw_embedding_search.py @@ -0,0 +1,139 @@ +from __future__ import annotations +import sys +import traceback +from typing import Optional +from agentkit.errors import AgentKitPluginError +from agentkit.hooks import agentkit_preprocessor +from agentkit.base.preprocessor import AgentKitPreprocessor +from agentkit.types import PluginConfigParam +from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService +from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException + +@agentkit_preprocessor(id="mediawiki_embedding_search", + name="页面内容提取", + description="每次对话前,在页面中提取相关内容") +class EmbeddingSearchPreprocessor(AgentKitPreprocessor): + CONFIG: list[PluginConfigParam] = [ + { + "id": "distance_limit", + "name": "向量查找距离限制", + "type": "float", + "required": True, + }, + { + "id": "extract_limit", + "name": "提取数量限制", + "type": "int", + "required": True, + "default": 15, + } + ] + + OUTPUT: list[PluginConfigParam] = [ + { + "id": "document", + "name": "抽取的文档", + "type": "string", + "default": "", + } + ] + + def __init__(self, props: dict, distance_limit: float, + extract_limit: int, **kwargs): + self.props = props + + self.distance_limit = distance_limit + self.extract_limit = extract_limit + + self.page_title = props.get("page_title", "") + self.db_service = props.get("db_service", None) + + if not self.db_service: + raise ValueError("db_service is required in props") + + if self.page_title: + self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title) + else: + self.embedding_search = None + + + async def __aenter__(self): + if self.embedding_search: + self.embedding_search.__aenter__() + return self + + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.embedding_search: + self.embedding_search.__aexit__(exc_type, exc_val, exc_tb) + + + async def on_open_conversation(self, on_progress=None): + if not self.embedding_search: + return + + mwapi = MediaWikiApi.create() + try: + if await self.embedding_search.should_update_page_index(): + if self.props.get("caller_type") == "user": + user_id = self.props.get("user") + usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") + transatcion_id = usage_res.get("transaction_id") + + await self.embedding_search.prepare_update_index() + + token_usage = await self.embedding_search.update_page_index(on_progress) + + if transatcion_id: + result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) + except MediaWikiPageNotFoundException: + pass + except MediaWikiApiException as e: + error_msg = "MediaWiki API error: %s" % str(e) + print(error_msg, file=sys.stderr) + traceback.print_exc() + + if transatcion_id: + await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg) + + raise AgentKitPluginError(f"MediaWiki API error: ${e.info}", e.code) + except EmbeddingRunningException: + error_msg = "Page index is running now" + + if transatcion_id: + await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg) + + raise AgentKitPluginError(error_msg, "page-index-running") + except ConnectionResetError: + pass # Ignore websocket close error + except Exception as e: + error_msg = str(e) + print(error_msg, file=sys.stderr) + traceback.print_exc() + + if transatcion_id: + await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg) + + raise AgentKitPluginError(f"Error preparing page index: ${e}", "page-index-error") + + + async def on_before_completion(self, prompt: str) -> Optional[dict]: + if not self.embedding_search: + return prompt + + try: + extracted_docs = await self.embedding_search.search(prompt, self.extract_limit, True, self.distance_limit) + generated_prompt = "" + for doc in extracted_docs: + generated_prompt += f"- [{doc['title']}]({doc['url']})\n" + + return { + "document": generated_prompt + } + except EmbeddingRunningException: + return + except Exception as e: + error_msg = str(e) + print(error_msg, file=sys.stderr) + traceback.print_exc() + raise AgentKitPluginError(f"Error getting related content: ${e}", "related-content-error") \ No newline at end of file diff --git a/agentkit/types.py b/agentkit/types.py new file mode 100644 index 0000000..c3ef611 --- /dev/null +++ b/agentkit/types.py @@ -0,0 +1,52 @@ +from __future__ import annotations +from typing import Any, Callable, List, Optional, TypedDict + + +class ConfigSelectItem(TypedDict): + label: str + value: Any + + +class PluginConfigParam(TypedDict): + id: str + type: str + name: str + description: str + default: Any + required: bool + min: Optional[float] + max: Optional[float] + options: Optional[list[ConfigSelectItem]] + + +class HookInfo(TypedDict): + id: str + name: str + description: str + class_: Callable + + +class BaseMessage(TypedDict): + role: str + content: str + function_call: Optional[Any] + tool_calls: Optional[list] + tool_call_id: Optional[str] + + +class LLMFunctionInfo(TypedDict): + name_for_llm: str + name_for_human: str + description: str + params: dict + + +class LLMFunctionResponse(TypedDict): + response: str + """The response text generated by the function.""" + + direct_return: bool + """Directly return the response to the user, without further processing.""" + + +OnProgressCallback = Callable[[int, int], Any] \ No newline at end of file diff --git a/server/model/embedding_search/page_index.py b/server/model/embedding_search/page_index.py index 5862bdd..2dfd6c3 100644 --- a/server/model/embedding_search/page_index.py +++ b/server/model/embedding_search/page_index.py @@ -127,9 +127,8 @@ class PageIndexHelper: self.table_initialized = True async def create_embedding_index(self): - pass - # await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);" - # .replace("/*_*/", self.table_name)) + await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);" + .replace("/*_*/", self.table_name)) def sha1_doc(self, doc: list): for item in doc: diff --git a/service/chat_complete.py b/service/chat_complete.py index 12b47b2..9f1a401 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -409,7 +409,7 @@ class ChatCompleteService: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) api_id = Config.get("chatcomplete.system_api_id", "default", str) - model_id = Config.get("chatcomplete.system_api_id", "default", str) + model_id = Config.get("chatcomplete.system_model_id", "default", str) openai_api = OpenAIApi.create(api_id) @@ -442,7 +442,7 @@ class ChatCompleteService: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) api_id = Config.get("chatcomplete.system_api_id", "default", str) - model_id = Config.get("chatcomplete.system_api_id", "default", str) + model_id = Config.get("chatcomplete.system_model_id", "default", str) openai_api = OpenAIApi.create(api_id)