You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
2.6 KiB
Python
65 lines
2.6 KiB
Python
from __future__ import annotations
|
|
import sys
|
|
import traceback
|
|
from typing import List, Dict, Callable, Optional
|
|
from agentkit.types import PluginConfigParam, HookInfo
|
|
from agentkit.utils.utils import create_plugin_params
|
|
|
|
|
|
class AgentKitHooks:
    """Class-level registry of hook classes, grouped by hook name.

    Hook classes are added with ``register`` under a hook name and an id, and
    later constructed in bulk by ``instantiate`` from a mapping of
    hook id -> constructor keyword arguments.
    """

    # Single dict shared by the class (``register`` mutates it in place):
    # hook name -> HookInfo entries, in registration order.
    HookRegistry: Dict[str, List[HookInfo]] = {}

    # Private sentinel so ``instantiate`` can tell "props omitted" apart from
    # an explicit ``None`` without relying on a shared mutable default.
    _PROPS_UNSET = object()

    @classmethod
    def register(cls, hook_name: str, class_: Callable,
                 id: str, name: str = "", description: str = "",
                 param_def: Optional[list[PluginConfigParam]] = None):
        """Register ``class_`` as a hook implementation under ``hook_name``.

        Args:
            hook_name: Hook category to register under; created on first use.
            class_: Hook class (or factory). ``instantiate`` later calls it as
                ``class_(props, **params)``.
            id: Identifier used to select this hook in ``instantiate``.
            name: Display name; an empty string falls back to ``id``.
            description: Free-form description stored with the entry.
            param_def: Accepted for API compatibility; not stored or used
                here. Defaults to ``None``.
        """
        # NOTE(review): ``param_def`` is dropped; if HookInfo is supposed to
        # carry parameter definitions, it must be added to the entry below —
        # confirm against the HookInfo definition.
        if name == "":
            name = id

        cls.HookRegistry.setdefault(hook_name, []).append(HookInfo(
            id=id,
            name=name,
            description=description,
            class_=class_,
        ))

    @classmethod
    def get(cls, hook_name: str) -> List[HookInfo]:
        """Return the hooks registered under ``hook_name`` (``[]`` if none).

        When the name exists, the returned list is the registry's own list,
        so callers should treat it as read-only.
        """
        return cls.HookRegistry.get(hook_name, [])

    @classmethod
    def instantiate(cls, hook_name: str, hook_item_params: dict,
                    props: Optional[dict] = _PROPS_UNSET) -> List[Callable]:
        """Construct every ``hook_name`` hook whose id appears in ``hook_item_params``.

        Args:
            hook_name: Hook category to look up in the registry.
            hook_item_params: Mapping of hook id -> keyword arguments for that
                hook's constructor. Registered hooks whose id is not a key are
                skipped; keys with no registered hook are ignored.
            props: Passed as the first positional argument to every hook
                constructor in this call. When omitted, a fresh empty dict is
                created per call, so hooks from one call cannot leak state
                into another's. An explicit ``None`` is passed through as-is.

        Returns:
            The successfully constructed instances, in registration order.
            A hook whose construction raises is reported on stderr (with a
            traceback) and skipped rather than aborting the whole batch.
        """
        if props is cls._PROPS_UNSET:
            props = {}

        instances = []
        for hook_info in cls.get(hook_name):
            hook_id = hook_info["id"]
            # Direct dict membership: O(1), no throwaway key list.
            if hook_id not in hook_item_params:
                continue
            params = hook_item_params[hook_id]
            try:
                class_ = hook_info["class_"]
                if hasattr(class_, "CONFIG"):
                    # Let create_plugin_params derive the constructor kwargs
                    # from the raw params and the class's declared CONFIG.
                    params = create_plugin_params(params, class_.CONFIG)
                instance = class_(props, **params)
            except Exception as e:
                # Best-effort: one broken hook must not block the others.
                print(f"Error instantiating hook {hook_id}: {e}", file=sys.stderr)
                traceback.print_exc()
            else:
                instances.append(instance)

        return instances
|
|
|
|
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                          param_def: Optional[list[PluginConfigParam]] = None):
    """Register ``orig_class`` as a ``"preprocessor"`` hook and return it unchanged.

    Args:
        orig_class: Hook class to register.
        id: Identifier used to select this hook when hooks are instantiated.
        name: Display name; an empty string means "use ``id``".
        description: Free-form description stored with the registration.
        param_def: Forwarded unchanged to ``AgentKitHooks.register``.
            Defaults to ``None`` rather than a shared mutable list.

    Returns:
        ``orig_class`` itself, so the call can wrap a class inline.
    """
    AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
    return orig_class
|
|
|
|
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                           param_def: Optional[list[PluginConfigParam]] = None):
    """Register ``orig_class`` as a ``"postprocessor"`` hook and return it unchanged.

    Args:
        orig_class: Hook class to register.
        id: Identifier used to select this hook when hooks are instantiated.
        name: Display name; an empty string means "use ``id``".
        description: Free-form description stored with the registration.
        param_def: Forwarded unchanged to ``AgentKitHooks.register``.
            Defaults to ``None`` rather than a shared mutable list.

    Returns:
        ``orig_class`` itself, so the call can wrap a class inline.
    """
    AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
    return orig_class
|
|
|
|
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
                               param_def: Optional[list[PluginConfigParam]] = None):
    """Register ``orig_class`` as a ``"func_call_handler"`` hook and return it unchanged.

    Args:
        orig_class: Hook class to register.
        id: Identifier used to select this hook when hooks are instantiated.
        name: Display name; an empty string means "use ``id``".
        description: Free-form description stored with the registration.
        param_def: Forwarded unchanged to ``AgentKitHooks.register``.
            Defaults to ``None`` rather than a shared mutable list.

    Returns:
        ``orig_class`` itself, so the call can wrap a class inline.
    """
    AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
    return orig_class