from __future__ import annotations import sys import traceback from typing import List, Dict, Callable, Optional from agentkit.types import PluginConfigParam, HookInfo from agentkit.utils.utils import create_plugin_params class AgentKitHooks: HookRegistry: Dict[str, List[HookInfo]] = {} @classmethod def register(cls, hook_name: str, class_: Callable, id: str, name: str = "", description: str = "", param_def: list[PluginConfigParam] = []): if hook_name not in cls.HookRegistry: cls.HookRegistry[hook_name] = [] if name == "": name = id cls.HookRegistry[hook_name].append(HookInfo( id=id, name=name, description=description, class_=class_ )) @classmethod def get(cls, hook_name: str) -> List[HookInfo]: return cls.HookRegistry.get(hook_name, []) @classmethod def instantiate(cls, hook_name: str, hook_item_params: dict, props: Optional[dict] = {}) -> List[Callable]: needle_hook_ids = list(hook_item_params.keys()) instances = [] for hook_info in cls.get(hook_name): if hook_info["id"] in needle_hook_ids: params = hook_item_params[hook_info["id"]] try: class_ = hook_info["class_"] if hasattr(class_, "CONFIG"): params = create_plugin_params(params, class_.CONFIG) instance = class_(props, **params) instances.append(instance) except Exception as e: print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr) traceback.print_exc() return instances def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "", param_def: list[PluginConfigParam] = []): AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def) return orig_class def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "", param_def: list[PluginConfigParam] = []): AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def) return orig_class def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "", param_def: list[PluginConfigParam] = []): AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def) return orig_class