You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.5 KiB
Python
34 lines
1.5 KiB
Python
from typing import Any, Union
|
|
from agentkit.base.llm_function import AgentKitLLMFunction
|
|
from agentkit.context import ConversationContext
|
|
from agentkit.errors import AgentKitNotFoundError
|
|
from agentkit.hooks import AgentKitHooks
|
|
from agentkit.types import LLMFunctionInfo, LLMFunctionResponse
|
|
|
|
|
|
class AgentKitLLMFunctionContainer:
    """Container for the LLM functions configured for an agent.

    The functions are instantiated through ``AgentKitHooks`` at construction
    time. The container can list their descriptions and dispatch a call to
    one of them by the name it is advertised under (``name_for_llm``).
    """

    def __init__(self, props: dict, llm_function_configs: dict[str, list]):
        """Create the container and instantiate its LLM functions.

        Args:
            props: Properties stored on the container and forwarded to
                ``AgentKitHooks.instantiate``.
            llm_function_configs: Configuration forwarded to
                ``AgentKitHooks.instantiate`` for the ``"llm_function"``
                hook type.
        """
        self.props = props
        # NOTE(review): assumes instantiate() returns a list of
        # AgentKitLLMFunction objects, as the annotation states — confirm.
        self.llm_function_list: list[AgentKitLLMFunction] = AgentKitHooks.instantiate(
            "llm_function", llm_function_configs, props
        )

    async def get_llm_function_info(self) -> list[LLMFunctionInfo]:
        """Collect the info of every LLM function that provides one.

        Functions whose ``get_llm_function_info()`` returns a falsy value
        (e.g. ``None``) are skipped.

        Returns:
            The collected info objects, in the order of
            ``self.llm_function_list``.
        """
        llm_function_info: list[LLMFunctionInfo] = []
        for llm_function in self.llm_function_list:
            info = await llm_function.get_llm_function_info()
            if info:
                llm_function_info.append(info)
        return llm_function_info

    async def call_llm_function(
        self, context: ConversationContext, llm_function_name: str, params: Any
    ) -> Union[str, LLMFunctionResponse]:
        """Run the LLM function advertised as ``llm_function_name``.

        The first function whose info has ``name_for_llm`` equal to
        ``llm_function_name`` is run. Afterwards the current task on
        ``context.progress`` is marked complete.

        Args:
            context: Conversation context passed to the function's ``run``.
            llm_function_name: The ``name_for_llm`` of the function to call.
            params: Arguments forwarded unchanged to the function's ``run``.

        Returns:
            Whatever the matched function's ``run`` returns.

        Raises:
            AgentKitNotFoundError: If no function is advertised under
                ``llm_function_name``.
        """
        for llm_function in self.llm_function_list:
            info = await llm_function.get_llm_function_info()
            if not info or info["name_for_llm"] != llm_function_name:
                continue

            result = await llm_function.run(context, params)
            # Mark the task as complete in case the function did not do so
            # itself. NOTE(review): this runs unconditionally, so it assumes
            # task_complete() is safe to call more than once — confirm.
            await context.progress.task_complete()
            return result

        # Raise, not return: handing back the exception object would let it
        # masquerade as a successful function result.
        raise AgentKitNotFoundError("LLM function not found", "llm_function", llm_function_name)