You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
2.6 KiB
Python

from __future__ import annotations
import sys
import traceback
from typing import List, Dict, Callable, Optional
from agentkit.types import PluginConfigParam, HookInfo
from agentkit.utils.utils import create_plugin_params
class AgentKitHooks:
    """Class-level registry of hook classes, grouped by hook name.

    Hook classes are recorded with :meth:`register` under a hook name and
    later turned into instances with :meth:`instantiate`. All state lives on
    the class itself, so every caller shares the same registry.
    """

    # Maps hook name -> hook descriptors, in registration order. This dict is
    # intentionally shared class state (it *is* the registry).
    HookRegistry: Dict[str, List[HookInfo]] = {}

    @classmethod
    def register(cls, hook_name: str, class_: Callable,
                 id: str, name: str = "", description: str = "",
                 param_def: Optional[List[PluginConfigParam]] = None):
        """Record ``class_`` as a hook implementation under ``hook_name``.

        Args:
            hook_name: Registry key the hook is filed under.
            class_: The callable that ``instantiate`` will later call.
            id: Identifier used to select this hook in ``instantiate``.
            name: Display name; an empty string falls back to ``id``.
            description: Free-form description stored with the hook.
            param_def: Accepted for interface compatibility.
                NOTE(review): this value is currently not stored anywhere;
                confirm whether ``HookInfo`` should carry it.
        """
        if hook_name not in cls.HookRegistry:
            cls.HookRegistry[hook_name] = []
        if name == "":
            name = id
        cls.HookRegistry[hook_name].append(HookInfo(
            id=id,
            name=name,
            description=description,
            class_=class_,
        ))

    @classmethod
    def get(cls, hook_name: str) -> List[HookInfo]:
        """Return the hooks registered under ``hook_name`` (``[]`` if none)."""
        return cls.HookRegistry.get(hook_name, [])

    @classmethod
    def instantiate(cls, hook_name: str, hook_item_params: dict,
                    props: Optional[dict] = None) -> List[Callable]:
        """Instantiate the registered hooks selected by ``hook_item_params``.

        Each hook registered under ``hook_name`` whose id is a key of
        ``hook_item_params`` is constructed as ``class_(props, **params)``,
        where ``params`` is the value stored under that id. If the hook class
        has a ``CONFIG`` attribute, ``params`` is first passed through
        ``create_plugin_params(params, class_.CONFIG)``.

        Construction is best-effort: a hook that raises is reported on
        stderr (message plus traceback) and skipped, so one broken hook does
        not prevent the others from loading.

        Args:
            hook_name: Registry key to look hooks up under.
            hook_item_params: Mapping of hook id -> keyword parameters.
            props: Object passed as the first positional argument to every
                hook. When omitted or ``None``, a fresh empty dict is created
                for this call so hooks never share state across calls.

        Returns:
            The successfully created instances, in registration order.
        """
        if props is None:
            # A new dict per call: a literal ``{}`` default would be a single
            # object shared (and possibly mutated) by every hook ever built.
            props = {}

        instances = []
        for hook_info in cls.get(hook_name):
            hook_id = hook_info["id"]
            if hook_id not in hook_item_params:
                continue
            params = hook_item_params[hook_id]
            try:
                class_ = hook_info["class_"]
                if hasattr(class_, "CONFIG"):
                    params = create_plugin_params(params, class_.CONFIG)
                instances.append(class_(props, **params))
            except Exception as e:
                # Deliberately broad: report and keep loading the remaining hooks.
                print(f"Error instantiating hook {hook_id}: {e}", file=sys.stderr)
                traceback.print_exc()
        return instances
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                          param_def: Optional[List[PluginConfigParam]] = None):
    """Register ``orig_class`` under the ``"preprocessor"`` hook and return it.

    Args:
        orig_class: The hook class to register; returned unchanged.
        id: Identifier the hook is registered with.
        name: Display name forwarded to the registry.
        description: Description forwarded to the registry.
        param_def: Parameter definitions forwarded to the registry; ``None``
            is treated as an empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Build the empty list per call instead of sharing one mutable default.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
    return orig_class
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                           param_def: Optional[List[PluginConfigParam]] = None):
    """Register ``orig_class`` under the ``"postprocessor"`` hook and return it.

    Args:
        orig_class: The hook class to register; returned unchanged.
        id: Identifier the hook is registered with.
        name: Display name forwarded to the registry.
        description: Description forwarded to the registry.
        param_def: Parameter definitions forwarded to the registry; ``None``
            is treated as an empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Build the empty list per call instead of sharing one mutable default.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
    return orig_class
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
                               param_def: Optional[List[PluginConfigParam]] = None):
    """Register ``orig_class`` under the ``"func_call_handler"`` hook and return it.

    Args:
        orig_class: The hook class to register; returned unchanged.
        id: Identifier the hook is registered with.
        name: Display name forwarded to the registry.
        description: Description forwarded to the registry.
        param_def: Parameter definitions forwarded to the registry; ``None``
            is treated as an empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Build the empty list per call instead of sharing one mutable default.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
    return orig_class