From 83bb6ad213b3e1e6136ae7a9fe3741325996f469 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Sun, 1 Dec 2024 09:01:14 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=87=BD=E6=95=B0=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E9=83=A8=E5=88=86=E7=9A=84=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agentkit/base/condition.py | 29 ++++++++++++++++++++++++++++ agentkit/base/llm_function.py | 8 ++++---- agentkit/base/modelapi.py | 5 +++-- agentkit/context.py | 4 ++++ agentkit/errors.py | 24 ++++++++++++++++++++++- agentkit/hooks.py | 11 +++++++---- agentkit/llm_function.py | 34 +++++++++++++++++++++++++++++++++ agentkit/types.py | 6 +++++- agentkit/utils/progress.py | 36 +++++++++++++++++++++++++++++++++++ agentkit/utils/utils.py | 11 +++++++++++ 10 files changed, 156 insertions(+), 12 deletions(-) create mode 100644 agentkit/base/condition.py create mode 100644 agentkit/llm_function.py create mode 100644 agentkit/utils/progress.py create mode 100644 agentkit/utils/utils.py diff --git a/agentkit/base/condition.py b/agentkit/base/condition.py new file mode 100644 index 0000000..13d20fd --- /dev/null +++ b/agentkit/base/condition.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from typing import Any +from agentkit.context import ConversationContext +from agentkit.types import PluginConfigParam + + +class AgentKitCondition: + INPUT: list[PluginConfigParam] = [] + """ + Input parameters for the condition. + """ + + def __init__(self, props: dict): + pass + + + async def __aenter__(self) -> AgentKitCondition: + return self + + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + return None + + + async def evaluate(self, context: ConversationContext, **kwargs) -> bool: + """ + Evaluate the condition. + """ + return False \ No newline at end of file diff --git a/agentkit/base/llm_function.py b/agentkit/base/llm_function.py index 44a5aa3..230f57d 100644 --- a/agentkit/base/llm_function.py +++ b/agentkit/base/llm_function.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Union from agentkit.context import ConversationContext -from agentkit.types import LLMFunctionInfo, PluginConfigParam +from agentkit.types import LLMFunctionInfo, LLMFunctionResponse, PluginConfigParam class AgentKitLLMFunction: @@ -9,7 +9,7 @@ class AgentKitLLMFunction: def __init__(self, props: dict, **kwargs): pass - + async def get_llm_function_info() -> Optional[LLMFunctionInfo]: """ @@ -19,5 +19,5 @@ class AgentKitLLMFunction: return None - async def run(self, context: ConversationContext, params: Any) -> str: + async def run(self, context: ConversationContext, params: Any) -> Union[str, LLMFunctionResponse]: return "Function not implemented" \ No newline at end of file diff --git a/agentkit/base/modelapi.py b/agentkit/base/modelapi.py index e95acb7..6036a09 100644 --- a/agentkit/base/modelapi.py +++ b/agentkit/base/modelapi.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any from agentkit.context import ConversationContext +from agentkit.llm_function import AgentKitLLMFunctionContainer from agentkit.types import PluginConfigParam class AgentKitModelApi: @@ -16,10 +17,10 @@ class AgentKitModelApi: async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: return None - async def chatcomplete(self, context: ConversationContext) -> ConversationContext: + async def chatcomplete(self, context: ConversationContext, llm_functions: AgentKitLLMFunctionContainer) -> ConversationContext: return context - async def chatcomplete_stream(self, context: ConversationContext) -> ConversationContext: + async def chatcomplete_stream(self, context: ConversationContext, llm_functions: AgentKitLLMFunctionContainer) -> ConversationContext: return context async def get_token_count(self, context: ConversationContext) -> int: diff --git a/agentkit/context.py b/agentkit/context.py index b13d97f..585dbaa 100644 --- a/agentkit/context.py +++ b/agentkit/context.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Optional from agentkit.types import BaseMessage +from agentkit.utils.progress import AgentKitTaskProgress class ConversationData: KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"] @@ -43,6 +44,9 @@ class ConversationContext(ConversationData): self.completion: Optional[str] = None """Completion result""" + self.progress = AgentKitTaskProgress() + """Task progress""" + def get_messages_by_role(self, role: str) -> list[BaseMessage]: return [msg for msg in self.history if msg["role"] == role] diff --git a/agentkit/errors.py b/agentkit/errors.py index e1b8c6a..0db0c8c 100644 --- a/agentkit/errors.py +++ b/agentkit/errors.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Optional class AgentKitPluginError(Exception): @@ -11,4 +12,25 @@ class AgentKitPluginError(Exception): return f"[{self.code}] ${self.message}" def __repr__(self): - return f"AgentPluginError({self.message}, {self.code})" \ No newline at end of file + return f"AgentPluginError({self.message}, {self.code})" + + +class AgentKitNotFoundError(Exception): + def __init__(self, message, resource_type: str = "", resource_id: Optional[str] = None): + self.message = message + self.resource_type = resource_type + self.resource_id = resource_id + + def __str__(self): + err_str = "AgentNotFoundError " + + if self.resource_type and self.resource_id: + err_str += f"[{self.resource_type}/{self.resource_id}]" + elif self.resource_type: + err_str += f"[{self.resource_type}]" + + err_str += f": {self.message}" + return err_str + + def __repr__(self): + return f"AgentNotFoundError({self.message}, {self.resource_type}, {self.resource_id})" \ No newline at end of file diff --git a/agentkit/hooks.py b/agentkit/hooks.py index 90e1625..8f0f456 100644 --- a/agentkit/hooks.py +++ b/agentkit/hooks.py @@ -1,8 +1,9 @@ from __future__ import annotations import sys import traceback -from typing import List, Dict, Callable +from typing import List, Dict, Callable, Optional from agentkit.types import PluginConfigParam, HookInfo +from agentkit.utils.utils import create_plugin_params class AgentKitHooks: @@ -22,7 +23,6 @@ class AgentKitHooks: id=id, name=name, description=description, - params_def=param_def, class_=class_ )) @@ -31,14 +31,17 @@ class AgentKitHooks: return cls.HookRegistry.get(hook_name, []) @classmethod - def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]: + def instantiate(cls, hook_name: str, hook_item_params: dict, props: Optional[dict] = {}) -> List[Callable]: needle_hook_ids = list(hook_item_params.keys()) instances = [] for hook_info in cls.get(hook_name): if hook_info["id"] in needle_hook_ids: params = hook_item_params[hook_info["id"]] try: - instance = hook_info["class_"](**params) + class_ = hook_info["class_"] + if hasattr(class_, "CONFIG"): + params = create_plugin_params(params, class_.CONFIG) + instance = class_(props, **params) instances.append(instance) except Exception as e: print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr) diff --git a/agentkit/llm_function.py b/agentkit/llm_function.py new file mode 100644 index 0000000..47cb970 --- /dev/null +++ b/agentkit/llm_function.py @@ -0,0 +1,34 @@ +from typing import Any, Union +from agentkit.base.llm_function import AgentKitLLMFunction +from agentkit.context import ConversationContext +from agentkit.errors import AgentKitNotFoundError +from agentkit.hooks import AgentKitHooks +from agentkit.types import LLMFunctionInfo, LLMFunctionResponse + + +class AgentKitLLMFunctionContainer: + def __init__(self, props: dict, llm_function_configs: dict[str, list]): + self.props = props + self.llm_function_list: list[AgentKitLLMFunction] = AgentKitHooks.instantiate( + "llm_function", llm_function_configs, props + ) + + async def get_llm_function_info(self) -> dict[str, dict]: + llm_function_info: list[LLMFunctionInfo] = [] + for llm_function in self.llm_function_list: + info = await llm_function.get_llm_function_info() + if info: + llm_function_info.append(info) + return llm_function_info + + + async def call_llm_function(self, context: ConversationContext, llm_function_name: str, params: Any) -> Union[str, LLMFunctionResponse]: + for llm_function in self.llm_function_list: + info = await llm_function.get_llm_function_info() + if info and info["name_for_llm"] == llm_function_name: + result = await llm_function.run(context, params) + # Mark the task as complete, if the function does not do it + await context.progress.task_complete() + return result + + return AgentKitNotFoundError("LLM function not found", "llm_function", llm_function_name) \ No newline at end of file diff --git a/agentkit/types.py b/agentkit/types.py index c3ef611..27da6ba 100644 --- a/agentkit/types.py +++ b/agentkit/types.py @@ -49,4 +49,8 @@ class LLMFunctionResponse(TypedDict): """Directly return the response to the user, without further processing.""" -OnProgressCallback = Callable[[int, int], Any] \ No newline at end of file +OnProgressCallback = Callable[[int, int], Any] + +OnTaskStartCallback = Callable[[str, str], Any] +OnTaskProgressCallback = Callable[[int, Optional[int], Optional[str]], Any] +OnTaskCompleteCallback = Callable[[str], Any] \ No newline at end of file diff --git a/agentkit/utils/progress.py b/agentkit/utils/progress.py new file mode 100644 index 0000000..0c18a24 --- /dev/null +++ b/agentkit/utils/progress.py @@ -0,0 +1,36 @@ +import inspect +from typing import Optional +from agentkit.types import OnTaskStartCallback, OnTaskProgressCallback, OnTaskCompleteCallback + +class AgentKitTaskProgress: + def __init__(self): + self.on_task_start_listeners: list[OnTaskStartCallback] = [] + self.on_task_progress_listeners: list[OnTaskProgressCallback] = [] + self.on_task_complete_listeners: list[OnTaskCompleteCallback] = [] + + def add_on_task_start_listener(self, listener: OnTaskStartCallback) -> None: + self.on_task_start_listeners.append(listener) + + def add_on_task_progress_listener(self, listener: OnTaskProgressCallback) -> None: + self.on_task_progress_listeners.append(listener) + + def add_on_task_complete_listener(self, listener: OnTaskCompleteCallback) -> None: + self.on_task_complete_listeners.append(listener) + + async def _emit(self, listeners: list, *args) -> None: + for listener in listeners: + ret = listener(*args) + if inspect.isawaitable(ret): + await ret + + async def task_start(self, task_title: str, task_description: str = "") -> None: + await self._emit(self.on_task_start_listeners, task_title, task_description) + + async def task_progress(self, current: int, total: Optional[int] = None, task_description: Optional[str] = None) -> None: + await self._emit(self.on_task_progress_listeners, current, total, task_description) + + async def task_progress_infinite(self, task_description: Optional[str] = None) -> None: + await self._emit(self.on_task_progress_listeners, 0, 0, task_description) + + async def task_complete(self) -> None: + await self._emit(self.on_task_complete_listeners) \ No newline at end of file diff --git a/agentkit/utils/utils.py b/agentkit/utils/utils.py new file mode 100644 index 0000000..ba45c71 --- /dev/null +++ b/agentkit/utils/utils.py @@ -0,0 +1,11 @@ +from agentkit.types import PluginConfigParam + + +def create_plugin_params(input_params: dict, param_def: list[PluginConfigParam]) -> dict: + params = {} + for param in param_def: + if param["id"] in input_params: + params[param["id"]] = input_params[param["id"]] + else: + params[param["id"]] = param["default"] + return params \ No newline at end of file