编写AgentKit框架,将AI对话分离到AgentKit
parent
071ab94829
commit
e5a643afd2
@ -0,0 +1,6 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class AgentKit:
|
||||||
|
def get_agent(self, agent_id: str):
|
||||||
|
pass
|
@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, Optional
|
||||||
|
from agentkit.context import ConversationContext
|
||||||
|
from agentkit.types import LLMFunctionInfo, PluginConfigParam
|
||||||
|
|
||||||
|
|
||||||
|
class AgentKitLLMFunction:
|
||||||
|
CONFIG: list[PluginConfigParam] = []
|
||||||
|
|
||||||
|
def __init__(self, props: dict, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def get_llm_function_info() -> Optional[LLMFunctionInfo]:
|
||||||
|
"""
|
||||||
|
Get function information for the LLM function.
|
||||||
|
return None to disable the function for this context.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def run(self, context: ConversationContext, params: Any) -> str:
|
||||||
|
return "Function not implemented"
|
@ -0,0 +1,26 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.context import ConversationContext
|
||||||
|
from agentkit.types import PluginConfigParam
|
||||||
|
|
||||||
|
class AgentKitModelApi:
|
||||||
|
CONFIG: list[PluginConfigParam] = []
|
||||||
|
|
||||||
|
def __init__(self, props: dict, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self) -> AgentKitModelApi:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def chatcomplete(self, context: ConversationContext) -> ConversationContext:
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def chatcomplete_stream(self, context: ConversationContext) -> ConversationContext:
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def get_token_count(self, context: ConversationContext) -> int:
|
||||||
|
return 0
|
@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Optional
|
||||||
|
from agentkit.context import ConversationContext
|
||||||
|
from agentkit.types import PluginConfigParam, OnProgressCallback
|
||||||
|
|
||||||
|
class AgentKitPostprocessor:
|
||||||
|
CONFIG: list[PluginConfigParam] = []
|
||||||
|
OUTPUT: list[PluginConfigParam] = []
|
||||||
|
|
||||||
|
def __init__(self, props: dict, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Called after the conversation, before the completion is returned to the user.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Called after the conversation, after the completion is returned to the user.
|
||||||
|
"""
|
||||||
|
return None
|
@ -0,0 +1,26 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Optional
|
||||||
|
from agentkit.context import ConversationContext
|
||||||
|
from agentkit.types import PluginConfigParam, OnProgressCallback
|
||||||
|
|
||||||
|
class AgentKitPreprocessor:
|
||||||
|
CONFIG: list[PluginConfigParam] = []
|
||||||
|
OUTPUT: list[PluginConfigParam] = []
|
||||||
|
|
||||||
|
def __init__(self, props: dict, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_before_completion(self, context: ConversationContext) -> Optional[dict]:
|
||||||
|
return None
|
@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Optional
|
||||||
|
from agentkit.types import BaseMessage
|
||||||
|
|
||||||
|
class ConversationData:
|
||||||
|
KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Persistent properties
|
||||||
|
self.history: list[BaseMessage] = []
|
||||||
|
"""Chat history"""
|
||||||
|
|
||||||
|
self.store: dict = {}
|
||||||
|
"""Store for plugins"""
|
||||||
|
|
||||||
|
self.user_id: Optional[str] = None
|
||||||
|
"""Current user ID"""
|
||||||
|
|
||||||
|
self.conversation_id: Optional[str] = None
|
||||||
|
"""Current conversation ID"""
|
||||||
|
|
||||||
|
self.conversation_rounds: int = 0
|
||||||
|
"""Number of conversation rounds"""
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {key: getattr(self, key) for key in self.KEYS_IN_STORE}
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationContext(ConversationData):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Temporary properties
|
||||||
|
self.prompt = ""
|
||||||
|
"""User prompt"""
|
||||||
|
|
||||||
|
self.system_prompt = ""
|
||||||
|
"""System prompt"""
|
||||||
|
|
||||||
|
self.session_data: dict = {}
|
||||||
|
"""Temporary data for plugins"""
|
||||||
|
|
||||||
|
self.completion: Optional[str] = None
|
||||||
|
"""Completion result"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_messages_by_role(self, role: str) -> list[BaseMessage]:
|
||||||
|
return [msg for msg in self.history if msg["role"] == role]
|
||||||
|
|
||||||
|
|
||||||
|
def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext:
|
||||||
|
sub_context = ConversationContext()
|
||||||
|
sub_context.user_id = self.user_id
|
||||||
|
|
||||||
|
if copy_store:
|
||||||
|
sub_context.store = self.store.copy()
|
||||||
|
|
||||||
|
if copy_session_data:
|
||||||
|
sub_context.session_data = self.session_data.copy()
|
||||||
|
|
||||||
|
return sub_context
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.session_data[key]
|
||||||
|
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.session_data[key] = value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self.session_data[key]
|
||||||
|
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return key in self.session_data
|
||||||
|
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.session_data)
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.session_data)
|
@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class AgentKitPluginError(Exception):
|
||||||
|
"""Base class for exceptions in the agent plugin module."""
|
||||||
|
def __init__(self, message, code=None):
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"[{self.code}] ${self.message}"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"AgentPluginError({self.message}, {self.code})"
|
@ -0,0 +1,62 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from typing import List, Dict, Callable
|
||||||
|
from agentkit.types import PluginConfigParam, HookInfo
|
||||||
|
|
||||||
|
|
||||||
|
class AgentKitHooks:
|
||||||
|
HookRegistry: Dict[str, List[HookInfo]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, hook_name: str, class_: Callable,
|
||||||
|
id: str, name: str = "", description: str = "",
|
||||||
|
param_def: list[PluginConfigParam] = []):
|
||||||
|
if hook_name not in cls.HookRegistry:
|
||||||
|
cls.HookRegistry[hook_name] = []
|
||||||
|
|
||||||
|
if name == "":
|
||||||
|
name = id
|
||||||
|
|
||||||
|
cls.HookRegistry[hook_name].append(HookInfo(
|
||||||
|
id=id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
params_def=param_def,
|
||||||
|
class_=class_
|
||||||
|
))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, hook_name: str) -> List[HookInfo]:
|
||||||
|
return cls.HookRegistry.get(hook_name, [])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]:
|
||||||
|
needle_hook_ids = list(hook_item_params.keys())
|
||||||
|
instances = []
|
||||||
|
for hook_info in cls.get(hook_name):
|
||||||
|
if hook_info["id"] in needle_hook_ids:
|
||||||
|
params = hook_item_params[hook_info["id"]]
|
||||||
|
try:
|
||||||
|
instance = hook_info["class_"](**params)
|
||||||
|
instances.append(instance)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return instances
|
||||||
|
|
||||||
|
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||||
|
param_def: list[PluginConfigParam] = []):
|
||||||
|
AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
|
||||||
|
return orig_class
|
||||||
|
|
||||||
|
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||||
|
param_def: list[PluginConfigParam] = []):
|
||||||
|
AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
|
||||||
|
return orig_class
|
||||||
|
|
||||||
|
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||||
|
param_def: list[PluginConfigParam] = []):
|
||||||
|
AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
|
||||||
|
return orig_class
|
@ -0,0 +1,139 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
from agentkit.errors import AgentKitPluginError
|
||||||
|
from agentkit.hooks import agentkit_preprocessor
|
||||||
|
from agentkit.base.preprocessor import AgentKitPreprocessor
|
||||||
|
from agentkit.types import PluginConfigParam
|
||||||
|
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
|
||||||
|
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
||||||
|
|
||||||
|
@agentkit_preprocessor(id="mediawiki_embedding_search",
|
||||||
|
name="页面内容提取",
|
||||||
|
description="每次对话前,在页面中提取相关内容")
|
||||||
|
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
|
||||||
|
CONFIG: list[PluginConfigParam] = [
|
||||||
|
{
|
||||||
|
"id": "distance_limit",
|
||||||
|
"name": "向量查找距离限制",
|
||||||
|
"type": "float",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "extract_limit",
|
||||||
|
"name": "提取数量限制",
|
||||||
|
"type": "int",
|
||||||
|
"required": True,
|
||||||
|
"default": 15,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
OUTPUT: list[PluginConfigParam] = [
|
||||||
|
{
|
||||||
|
"id": "document",
|
||||||
|
"name": "抽取的文档",
|
||||||
|
"type": "string",
|
||||||
|
"default": "",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, props: dict, distance_limit: float,
|
||||||
|
extract_limit: int, **kwargs):
|
||||||
|
self.props = props
|
||||||
|
|
||||||
|
self.distance_limit = distance_limit
|
||||||
|
self.extract_limit = extract_limit
|
||||||
|
|
||||||
|
self.page_title = props.get("page_title", "")
|
||||||
|
self.db_service = props.get("db_service", None)
|
||||||
|
|
||||||
|
if not self.db_service:
|
||||||
|
raise ValueError("db_service is required in props")
|
||||||
|
|
||||||
|
if self.page_title:
|
||||||
|
self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
|
||||||
|
else:
|
||||||
|
self.embedding_search = None
|
||||||
|
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
if self.embedding_search:
|
||||||
|
self.embedding_search.__aenter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self.embedding_search:
|
||||||
|
self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_open_conversation(self, on_progress=None):
|
||||||
|
if not self.embedding_search:
|
||||||
|
return
|
||||||
|
|
||||||
|
mwapi = MediaWikiApi.create()
|
||||||
|
try:
|
||||||
|
if await self.embedding_search.should_update_page_index():
|
||||||
|
if self.props.get("caller_type") == "user":
|
||||||
|
user_id = self.props.get("user")
|
||||||
|
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
|
||||||
|
transatcion_id = usage_res.get("transaction_id")
|
||||||
|
|
||||||
|
await self.embedding_search.prepare_update_index()
|
||||||
|
|
||||||
|
token_usage = await self.embedding_search.update_page_index(on_progress)
|
||||||
|
|
||||||
|
if transatcion_id:
|
||||||
|
result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
|
||||||
|
except MediaWikiPageNotFoundException:
|
||||||
|
pass
|
||||||
|
except MediaWikiApiException as e:
|
||||||
|
error_msg = "MediaWiki API error: %s" % str(e)
|
||||||
|
print(error_msg, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if transatcion_id:
|
||||||
|
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||||
|
|
||||||
|
raise AgentKitPluginError(f"MediaWiki API error: ${e.info}", e.code)
|
||||||
|
except EmbeddingRunningException:
|
||||||
|
error_msg = "Page index is running now"
|
||||||
|
|
||||||
|
if transatcion_id:
|
||||||
|
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||||
|
|
||||||
|
raise AgentKitPluginError(error_msg, "page-index-running")
|
||||||
|
except ConnectionResetError:
|
||||||
|
pass # Ignore websocket close error
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
print(error_msg, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if transatcion_id:
|
||||||
|
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||||
|
|
||||||
|
raise AgentKitPluginError(f"Error preparing page index: ${e}", "page-index-error")
|
||||||
|
|
||||||
|
|
||||||
|
async def on_before_completion(self, prompt: str) -> Optional[dict]:
|
||||||
|
if not self.embedding_search:
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_docs = await self.embedding_search.search(prompt, self.extract_limit, True, self.distance_limit)
|
||||||
|
generated_prompt = ""
|
||||||
|
for doc in extracted_docs:
|
||||||
|
generated_prompt += f"- [{doc['title']}]({doc['url']})\n"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"document": generated_prompt
|
||||||
|
}
|
||||||
|
except EmbeddingRunningException:
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
print(error_msg, file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise AgentKitPluginError(f"Error getting related content: ${e}", "related-content-error")
|
@ -0,0 +1,52 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, Callable, List, Optional, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSelectItem(TypedDict):
|
||||||
|
label: str
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
|
class PluginConfigParam(TypedDict):
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
default: Any
|
||||||
|
required: bool
|
||||||
|
min: Optional[float]
|
||||||
|
max: Optional[float]
|
||||||
|
options: Optional[list[ConfigSelectItem]]
|
||||||
|
|
||||||
|
|
||||||
|
class HookInfo(TypedDict):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
class_: Callable
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessage(TypedDict):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
function_call: Optional[Any]
|
||||||
|
tool_calls: Optional[list]
|
||||||
|
tool_call_id: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFunctionInfo(TypedDict):
|
||||||
|
name_for_llm: str
|
||||||
|
name_for_human: str
|
||||||
|
description: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFunctionResponse(TypedDict):
|
||||||
|
response: str
|
||||||
|
"""The response text generated by the function."""
|
||||||
|
|
||||||
|
direct_return: bool
|
||||||
|
"""Directly return the response to the user, without further processing."""
|
||||||
|
|
||||||
|
|
||||||
|
OnProgressCallback = Callable[[int, int], Any]
|
Loading…
Reference in New Issue