编写AgentKit框架,将AI对话分离到AgentKit
parent
071ab94829
commit
e5a643afd2
@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AgentKit:
|
||||
def get_agent(self, agent_id: str):
|
||||
pass
|
@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Optional
|
||||
from agentkit.context import ConversationContext
|
||||
from agentkit.types import LLMFunctionInfo, PluginConfigParam
|
||||
|
||||
|
||||
class AgentKitLLMFunction:
|
||||
CONFIG: list[PluginConfigParam] = []
|
||||
|
||||
def __init__(self, props: dict, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
async def get_llm_function_info() -> Optional[LLMFunctionInfo]:
|
||||
"""
|
||||
Get function information for the LLM function.
|
||||
return None to disable the function for this context.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
async def run(self, context: ConversationContext, params: Any) -> str:
|
||||
return "Function not implemented"
|
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from agentkit.context import ConversationContext
|
||||
from agentkit.types import PluginConfigParam
|
||||
|
||||
class AgentKitModelApi:
|
||||
CONFIG: list[PluginConfigParam] = []
|
||||
|
||||
def __init__(self, props: dict, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> AgentKitModelApi:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
return None
|
||||
|
||||
async def chatcomplete(self, context: ConversationContext) -> ConversationContext:
|
||||
return context
|
||||
|
||||
async def chatcomplete_stream(self, context: ConversationContext) -> ConversationContext:
|
||||
return context
|
||||
|
||||
async def get_token_count(self, context: ConversationContext) -> int:
|
||||
return 0
|
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from agentkit.context import ConversationContext
|
||||
from agentkit.types import PluginConfigParam, OnProgressCallback
|
||||
|
||||
class AgentKitPostprocessor:
|
||||
CONFIG: list[PluginConfigParam] = []
|
||||
OUTPUT: list[PluginConfigParam] = []
|
||||
|
||||
def __init__(self, props: dict, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||
return None
|
||||
|
||||
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||
return None
|
||||
|
||||
async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
|
||||
"""
|
||||
Called after the conversation, before the completion is returned to the user.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
|
||||
"""
|
||||
Called after the conversation, after the completion is returned to the user.
|
||||
"""
|
||||
return None
|
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from agentkit.context import ConversationContext
|
||||
from agentkit.types import PluginConfigParam, OnProgressCallback
|
||||
|
||||
class AgentKitPreprocessor:
|
||||
CONFIG: list[PluginConfigParam] = []
|
||||
OUTPUT: list[PluginConfigParam] = []
|
||||
|
||||
def __init__(self, props: dict, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||
return None
|
||||
|
||||
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
|
||||
return None
|
||||
|
||||
async def on_before_completion(self, context: ConversationContext) -> Optional[dict]:
|
||||
return None
|
@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from agentkit.types import BaseMessage
|
||||
|
||||
class ConversationData:
|
||||
KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"]
|
||||
|
||||
def __init__(self):
|
||||
# Persistent properties
|
||||
self.history: list[BaseMessage] = []
|
||||
"""Chat history"""
|
||||
|
||||
self.store: dict = {}
|
||||
"""Store for plugins"""
|
||||
|
||||
self.user_id: Optional[str] = None
|
||||
"""Current user ID"""
|
||||
|
||||
self.conversation_id: Optional[str] = None
|
||||
"""Current conversation ID"""
|
||||
|
||||
self.conversation_rounds: int = 0
|
||||
"""Number of conversation rounds"""
|
||||
|
||||
def to_dict(self):
|
||||
return {key: getattr(self, key) for key in self.KEYS_IN_STORE}
|
||||
|
||||
|
||||
class ConversationContext(ConversationData):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Temporary properties
|
||||
self.prompt = ""
|
||||
"""User prompt"""
|
||||
|
||||
self.system_prompt = ""
|
||||
"""System prompt"""
|
||||
|
||||
self.session_data: dict = {}
|
||||
"""Temporary data for plugins"""
|
||||
|
||||
self.completion: Optional[str] = None
|
||||
"""Completion result"""
|
||||
|
||||
|
||||
def get_messages_by_role(self, role: str) -> list[BaseMessage]:
|
||||
return [msg for msg in self.history if msg["role"] == role]
|
||||
|
||||
|
||||
def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext:
|
||||
sub_context = ConversationContext()
|
||||
sub_context.user_id = self.user_id
|
||||
|
||||
if copy_store:
|
||||
sub_context.store = self.store.copy()
|
||||
|
||||
if copy_session_data:
|
||||
sub_context.session_data = self.session_data.copy()
|
||||
|
||||
return sub_context
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.session_data[key]
|
||||
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.session_data[key] = value
|
||||
return value
|
||||
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.session_data[key]
|
||||
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.session_data
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.session_data)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.session_data)
|
@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AgentKitPluginError(Exception):
|
||||
"""Base class for exceptions in the agent plugin module."""
|
||||
def __init__(self, message, code=None):
|
||||
self.message = message
|
||||
self.code = code
|
||||
|
||||
def __str__(self):
|
||||
return f"[{self.code}] ${self.message}"
|
||||
|
||||
def __repr__(self):
|
||||
return f"AgentPluginError({self.message}, {self.code})"
|
@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Dict, Callable
|
||||
from agentkit.types import PluginConfigParam, HookInfo
|
||||
|
||||
|
||||
class AgentKitHooks:
|
||||
HookRegistry: Dict[str, List[HookInfo]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, hook_name: str, class_: Callable,
|
||||
id: str, name: str = "", description: str = "",
|
||||
param_def: list[PluginConfigParam] = []):
|
||||
if hook_name not in cls.HookRegistry:
|
||||
cls.HookRegistry[hook_name] = []
|
||||
|
||||
if name == "":
|
||||
name = id
|
||||
|
||||
cls.HookRegistry[hook_name].append(HookInfo(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
params_def=param_def,
|
||||
class_=class_
|
||||
))
|
||||
|
||||
@classmethod
|
||||
def get(cls, hook_name: str) -> List[HookInfo]:
|
||||
return cls.HookRegistry.get(hook_name, [])
|
||||
|
||||
@classmethod
|
||||
def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]:
|
||||
needle_hook_ids = list(hook_item_params.keys())
|
||||
instances = []
|
||||
for hook_info in cls.get(hook_name):
|
||||
if hook_info["id"] in needle_hook_ids:
|
||||
params = hook_item_params[hook_info["id"]]
|
||||
try:
|
||||
instance = hook_info["class_"](**params)
|
||||
instances.append(instance)
|
||||
except Exception as e:
|
||||
print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
return instances
|
||||
|
||||
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||
param_def: list[PluginConfigParam] = []):
|
||||
AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
|
||||
return orig_class
|
||||
|
||||
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||
param_def: list[PluginConfigParam] = []):
|
||||
AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
|
||||
return orig_class
|
||||
|
||||
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
|
||||
param_def: list[PluginConfigParam] = []):
|
||||
AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
|
||||
return orig_class
|
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from agentkit.errors import AgentKitPluginError
|
||||
from agentkit.hooks import agentkit_preprocessor
|
||||
from agentkit.base.preprocessor import AgentKitPreprocessor
|
||||
from agentkit.types import PluginConfigParam
|
||||
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
|
||||
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
||||
|
||||
@agentkit_preprocessor(id="mediawiki_embedding_search",
|
||||
name="页面内容提取",
|
||||
description="每次对话前,在页面中提取相关内容")
|
||||
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
|
||||
CONFIG: list[PluginConfigParam] = [
|
||||
{
|
||||
"id": "distance_limit",
|
||||
"name": "向量查找距离限制",
|
||||
"type": "float",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"id": "extract_limit",
|
||||
"name": "提取数量限制",
|
||||
"type": "int",
|
||||
"required": True,
|
||||
"default": 15,
|
||||
}
|
||||
]
|
||||
|
||||
OUTPUT: list[PluginConfigParam] = [
|
||||
{
|
||||
"id": "document",
|
||||
"name": "抽取的文档",
|
||||
"type": "string",
|
||||
"default": "",
|
||||
}
|
||||
]
|
||||
|
||||
def __init__(self, props: dict, distance_limit: float,
|
||||
extract_limit: int, **kwargs):
|
||||
self.props = props
|
||||
|
||||
self.distance_limit = distance_limit
|
||||
self.extract_limit = extract_limit
|
||||
|
||||
self.page_title = props.get("page_title", "")
|
||||
self.db_service = props.get("db_service", None)
|
||||
|
||||
if not self.db_service:
|
||||
raise ValueError("db_service is required in props")
|
||||
|
||||
if self.page_title:
|
||||
self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
|
||||
else:
|
||||
self.embedding_search = None
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.embedding_search:
|
||||
self.embedding_search.__aenter__()
|
||||
return self
|
||||
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.embedding_search:
|
||||
self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
async def on_open_conversation(self, on_progress=None):
|
||||
if not self.embedding_search:
|
||||
return
|
||||
|
||||
mwapi = MediaWikiApi.create()
|
||||
try:
|
||||
if await self.embedding_search.should_update_page_index():
|
||||
if self.props.get("caller_type") == "user":
|
||||
user_id = self.props.get("user")
|
||||
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
|
||||
transatcion_id = usage_res.get("transaction_id")
|
||||
|
||||
await self.embedding_search.prepare_update_index()
|
||||
|
||||
token_usage = await self.embedding_search.update_page_index(on_progress)
|
||||
|
||||
if transatcion_id:
|
||||
result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
|
||||
except MediaWikiPageNotFoundException:
|
||||
pass
|
||||
except MediaWikiApiException as e:
|
||||
error_msg = "MediaWiki API error: %s" % str(e)
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if transatcion_id:
|
||||
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||
|
||||
raise AgentKitPluginError(f"MediaWiki API error: ${e.info}", e.code)
|
||||
except EmbeddingRunningException:
|
||||
error_msg = "Page index is running now"
|
||||
|
||||
if transatcion_id:
|
||||
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||
|
||||
raise AgentKitPluginError(error_msg, "page-index-running")
|
||||
except ConnectionResetError:
|
||||
pass # Ignore websocket close error
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if transatcion_id:
|
||||
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
|
||||
|
||||
raise AgentKitPluginError(f"Error preparing page index: ${e}", "page-index-error")
|
||||
|
||||
|
||||
async def on_before_completion(self, prompt: str) -> Optional[dict]:
|
||||
if not self.embedding_search:
|
||||
return prompt
|
||||
|
||||
try:
|
||||
extracted_docs = await self.embedding_search.search(prompt, self.extract_limit, True, self.distance_limit)
|
||||
generated_prompt = ""
|
||||
for doc in extracted_docs:
|
||||
generated_prompt += f"- [{doc['title']}]({doc['url']})\n"
|
||||
|
||||
return {
|
||||
"document": generated_prompt
|
||||
}
|
||||
except EmbeddingRunningException:
|
||||
return
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(error_msg, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
raise AgentKitPluginError(f"Error getting related content: ${e}", "related-content-error")
|
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Callable, List, Optional, TypedDict
|
||||
|
||||
|
||||
class ConfigSelectItem(TypedDict):
|
||||
label: str
|
||||
value: Any
|
||||
|
||||
|
||||
class PluginConfigParam(TypedDict):
|
||||
id: str
|
||||
type: str
|
||||
name: str
|
||||
description: str
|
||||
default: Any
|
||||
required: bool
|
||||
min: Optional[float]
|
||||
max: Optional[float]
|
||||
options: Optional[list[ConfigSelectItem]]
|
||||
|
||||
|
||||
class HookInfo(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
class_: Callable
|
||||
|
||||
|
||||
class BaseMessage(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
function_call: Optional[Any]
|
||||
tool_calls: Optional[list]
|
||||
tool_call_id: Optional[str]
|
||||
|
||||
|
||||
class LLMFunctionInfo(TypedDict):
|
||||
name_for_llm: str
|
||||
name_for_human: str
|
||||
description: str
|
||||
params: dict
|
||||
|
||||
|
||||
class LLMFunctionResponse(TypedDict):
|
||||
response: str
|
||||
"""The response text generated by the function."""
|
||||
|
||||
direct_return: bool
|
||||
"""Directly return the response to the user, without further processing."""
|
||||
|
||||
|
||||
OnProgressCallback = Callable[[int, int], Any]
|
Loading…
Reference in New Issue