You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
62 lines
2.4 KiB
Python
62 lines
2.4 KiB
Python
2 months ago
|
from __future__ import annotations
|
||
|
import sys
|
||
|
import traceback
|
||
|
from typing import List, Dict, Callable
|
||
|
from agentkit.types import PluginConfigParam, HookInfo
|
||
|
|
||
|
|
||
|
class AgentKitHooks:
    """Registry of hook implementations, grouped by hook name.

    Implementations are registered under a hook name (e.g. ``"preprocessor"``)
    and later instantiated on demand from per-hook constructor parameters.
    """

    # Maps hook name -> registered HookInfo entries, in registration order.
    # This is a class attribute, so one registry is shared by every caller.
    HookRegistry: Dict[str, List[HookInfo]] = {}

    @classmethod
    def register(cls, hook_name: str, class_: Callable,
                 id: str, name: str = "", description: str = "",
                 param_def: list[PluginConfigParam] | None = None):
        """Register ``class_`` as an implementation of ``hook_name``.

        Args:
            hook_name: Hook category to register under.
            class_: Callable (typically a class) that :meth:`instantiate`
                calls with keyword arguments.
            id: Identifier used to select this entry in :meth:`instantiate`.
            name: Display name; falls back to ``id`` when empty.
            description: Free-form description stored on the entry.
            param_def: Parameter definitions stored as ``params_def``.
                ``None`` (the default) stores a fresh empty list.
        """
        # Build a fresh list per call. A ``= []`` default is a single list
        # object, and it would be stored in (and shared by) every entry
        # registered without explicit parameter definitions.
        if param_def is None:
            param_def = []
        if not name:
            name = id

        cls.HookRegistry.setdefault(hook_name, []).append(HookInfo(
            id=id,
            name=name,
            description=description,
            params_def=param_def,
            class_=class_,
        ))

    @classmethod
    def get(cls, hook_name: str) -> List[HookInfo]:
        """Return the entries registered under ``hook_name`` (``[]`` if none)."""
        return cls.HookRegistry.get(hook_name, [])

    @classmethod
    def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]:
        """Instantiate the registered hooks selected by ``hook_item_params``.

        Args:
            hook_name: Hook category whose registered entries are considered.
            hook_item_params: Maps hook id -> keyword arguments for that
                hook's ``class_``. Only registered hooks whose id is a key
                here are instantiated.

        Returns:
            The successfully created instances, in registration order. A hook
            whose constructor raises is skipped; its error message and
            traceback are written to stderr.
        """
        instances = []
        for hook_info in cls.get(hook_name):
            hook_id = hook_info["id"]
            # Direct dict membership test: O(1), no throwaway key list.
            if hook_id not in hook_item_params:
                continue
            params = hook_item_params[hook_id]
            try:
                instance = hook_info["class_"](**params)
            except Exception as e:
                # Best effort: one failing hook must not block the others.
                print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr)
                traceback.print_exc()
            else:
                instances.append(instance)

        return instances
|
||
|
|
||
|
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                          param_def: list[PluginConfigParam] | None = None):
    """Register ``orig_class`` as a ``"preprocessor"`` hook and return it unchanged.

    Args:
        orig_class: Callable (typically a class) to register.
        id: Hook identifier used to select it at instantiation time.
        name: Display name; an empty name falls back to ``id`` in registration.
        description: Free-form description stored with the hook.
        param_def: Parameter definitions for the hook; ``None`` (the default)
            registers a fresh empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Fresh list per call: a ``= []`` default would be one list object
    # shared by every hook registered without parameter definitions.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
    return orig_class
|
||
|
|
||
|
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
                           param_def: list[PluginConfigParam] | None = None):
    """Register ``orig_class`` as a ``"postprocessor"`` hook and return it unchanged.

    Args:
        orig_class: Callable (typically a class) to register.
        id: Hook identifier used to select it at instantiation time.
        name: Display name; an empty name falls back to ``id`` in registration.
        description: Free-form description stored with the hook.
        param_def: Parameter definitions for the hook; ``None`` (the default)
            registers a fresh empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Fresh list per call: a ``= []`` default would be one list object
    # shared by every hook registered without parameter definitions.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
    return orig_class
|
||
|
|
||
|
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
                               param_def: list[PluginConfigParam] | None = None):
    """Register ``orig_class`` as a ``"func_call_handler"`` hook and return it unchanged.

    Args:
        orig_class: Callable (typically a class) to register.
        id: Hook identifier used to select it at instantiation time.
        name: Display name; an empty name falls back to ``id`` in registration.
        description: Free-form description stored with the hook.
        param_def: Parameter definitions for the hook; ``None`` (the default)
            registers a fresh empty list.

    Returns:
        ``orig_class`` itself.
    """
    # Fresh list per call: a ``= []`` default would be one list object
    # shared by every hook registered without parameter definitions.
    if param_def is None:
        param_def = []
    AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
    return orig_class
|