编写AgentKit框架,将AI对话分离到AgentKit

master
落雨楓 2 months ago
parent 071ab94829
commit e5a643afd2

@ -0,0 +1,6 @@
from __future__ import annotations
class AgentKit:
def get_agent(self, agent_id: str):
pass

@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Any, Optional
from agentkit.context import ConversationContext
from agentkit.types import LLMFunctionInfo, PluginConfigParam
class AgentKitLLMFunction:
CONFIG: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
pass
async def get_llm_function_info() -> Optional[LLMFunctionInfo]:
"""
Get function information for the LLM function.
return None to disable the function for this context.
"""
return None
async def run(self, context: ConversationContext, params: Any) -> str:
return "Function not implemented"

@ -0,0 +1,26 @@
from __future__ import annotations
from typing import Any
from agentkit.context import ConversationContext
from agentkit.types import PluginConfigParam
class AgentKitModelApi:
CONFIG: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
pass
async def __aenter__(self) -> AgentKitModelApi:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None
async def chatcomplete(self, context: ConversationContext) -> ConversationContext:
return context
async def chatcomplete_stream(self, context: ConversationContext) -> ConversationContext:
return context
async def get_token_count(self, context: ConversationContext) -> int:
return 0

@ -0,0 +1,35 @@
from __future__ import annotations
from typing import Optional
from agentkit.context import ConversationContext
from agentkit.types import PluginConfigParam, OnProgressCallback
class AgentKitPostprocessor:
CONFIG: list[PluginConfigParam] = []
OUTPUT: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_after_completion(self, context: ConversationContext) -> Optional[dict]:
"""
Called after the conversation, before the completion is returned to the user.
"""
return None
async def on_after_completion_background(self, context: ConversationContext) -> Optional[dict]:
"""
Called after the conversation, after the completion is returned to the user.
"""
return None

@ -0,0 +1,26 @@
from __future__ import annotations
from typing import Optional
from agentkit.context import ConversationContext
from agentkit.types import PluginConfigParam, OnProgressCallback
class AgentKitPreprocessor:
CONFIG: list[PluginConfigParam] = []
OUTPUT: list[PluginConfigParam] = []
def __init__(self, props: dict, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def on_create_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_open_conversation(self, on_progress: Optional[OnProgressCallback] = None) -> None:
return None
async def on_before_completion(self, context: ConversationContext) -> Optional[dict]:
return None

@ -0,0 +1,86 @@
from __future__ import annotations
from typing import Optional
from agentkit.types import BaseMessage
class ConversationData:
KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"]
def __init__(self):
# Persistent properties
self.history: list[BaseMessage] = []
"""Chat history"""
self.store: dict = {}
"""Store for plugins"""
self.user_id: Optional[str] = None
"""Current user ID"""
self.conversation_id: Optional[str] = None
"""Current conversation ID"""
self.conversation_rounds: int = 0
"""Number of conversation rounds"""
def to_dict(self):
return {key: getattr(self, key) for key in self.KEYS_IN_STORE}
class ConversationContext(ConversationData):
def __init__(self):
super().__init__()
# Temporary properties
self.prompt = ""
"""User prompt"""
self.system_prompt = ""
"""System prompt"""
self.session_data: dict = {}
"""Temporary data for plugins"""
self.completion: Optional[str] = None
"""Completion result"""
def get_messages_by_role(self, role: str) -> list[BaseMessage]:
return [msg for msg in self.history if msg["role"] == role]
def create_sub_context(self, copy_store = False, copy_session_data = False) -> ConversationContext:
sub_context = ConversationContext()
sub_context.user_id = self.user_id
if copy_store:
sub_context.store = self.store.copy()
if copy_session_data:
sub_context.session_data = self.session_data.copy()
return sub_context
def __getitem__(self, key):
return self.session_data[key]
def __setitem__(self, key, value):
self.session_data[key] = value
return value
def __delitem__(self, key):
del self.session_data[key]
def __contains__(self, key):
return key in self.session_data
def __iter__(self):
return iter(self.session_data)
def __len__(self):
return len(self.session_data)

@ -0,0 +1,14 @@
from __future__ import annotations
class AgentKitPluginError(Exception):
"""Base class for exceptions in the agent plugin module."""
def __init__(self, message, code=None):
self.message = message
self.code = code
def __str__(self):
return f"[{self.code}] ${self.message}"
def __repr__(self):
return f"AgentPluginError({self.message}, {self.code})"

@ -0,0 +1,62 @@
from __future__ import annotations
import sys
import traceback
from typing import List, Dict, Callable
from agentkit.types import PluginConfigParam, HookInfo
class AgentKitHooks:
HookRegistry: Dict[str, List[HookInfo]] = {}
@classmethod
def register(cls, hook_name: str, class_: Callable,
id: str, name: str = "", description: str = "",
param_def: list[PluginConfigParam] = []):
if hook_name not in cls.HookRegistry:
cls.HookRegistry[hook_name] = []
if name == "":
name = id
cls.HookRegistry[hook_name].append(HookInfo(
id=id,
name=name,
description=description,
params_def=param_def,
class_=class_
))
@classmethod
def get(cls, hook_name: str) -> List[HookInfo]:
return cls.HookRegistry.get(hook_name, [])
@classmethod
def instantiate(cls, hook_name: str, hook_item_params: dict) -> List[Callable]:
needle_hook_ids = list(hook_item_params.keys())
instances = []
for hook_info in cls.get(hook_name):
if hook_info["id"] in needle_hook_ids:
params = hook_item_params[hook_info["id"]]
try:
instance = hook_info["class_"](**params)
instances.append(instance)
except Exception as e:
print(f"Error instantiating hook {hook_info['id']}: {e}", file=sys.stderr)
traceback.print_exc()
return instances
def agentkit_preprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
param_def: list[PluginConfigParam] = []):
AgentKitHooks.register("preprocessor", orig_class, id, name, description, param_def)
return orig_class
def agentkit_postprocessor(orig_class: Callable, id: str, name: str = "", description: str = "",
param_def: list[PluginConfigParam] = []):
AgentKitHooks.register("postprocessor", orig_class, id, name, description, param_def)
return orig_class
def agentkit_func_call_handler(orig_class: Callable, id: str, name: str = "", description: str = "",
param_def: list[PluginConfigParam] = []):
AgentKitHooks.register("func_call_handler", orig_class, id, name, description, param_def)
return orig_class

@ -0,0 +1,139 @@
from __future__ import annotations
import sys
import traceback
from typing import Optional
from agentkit.errors import AgentKitPluginError
from agentkit.hooks import agentkit_preprocessor
from agentkit.base.preprocessor import AgentKitPreprocessor
from agentkit.types import PluginConfigParam
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
@agentkit_preprocessor(id="mediawiki_embedding_search",
name="页面内容提取",
description="每次对话前,在页面中提取相关内容")
class EmbeddingSearchPreprocessor(AgentKitPreprocessor):
CONFIG: list[PluginConfigParam] = [
{
"id": "distance_limit",
"name": "向量查找距离限制",
"type": "float",
"required": True,
},
{
"id": "extract_limit",
"name": "提取数量限制",
"type": "int",
"required": True,
"default": 15,
}
]
OUTPUT: list[PluginConfigParam] = [
{
"id": "document",
"name": "抽取的文档",
"type": "string",
"default": "",
}
]
def __init__(self, props: dict, distance_limit: float,
extract_limit: int, **kwargs):
self.props = props
self.distance_limit = distance_limit
self.extract_limit = extract_limit
self.page_title = props.get("page_title", "")
self.db_service = props.get("db_service", None)
if not self.db_service:
raise ValueError("db_service is required in props")
if self.page_title:
self.embedding_search = EmbeddingSearchService(self.db_service, self.page_title)
else:
self.embedding_search = None
async def __aenter__(self):
if self.embedding_search:
self.embedding_search.__aenter__()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.embedding_search:
self.embedding_search.__aexit__(exc_type, exc_val, exc_tb)
async def on_open_conversation(self, on_progress=None):
if not self.embedding_search:
return
mwapi = MediaWikiApi.create()
try:
if await self.embedding_search.should_update_page_index():
if self.props.get("caller_type") == "user":
user_id = self.props.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
await self.embedding_search.prepare_update_index()
token_usage = await self.embedding_search.update_page_index(on_progress)
if transatcion_id:
result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
except MediaWikiPageNotFoundException:
pass
except MediaWikiApiException as e:
error_msg = "MediaWiki API error: %s" % str(e)
print(error_msg, file=sys.stderr)
traceback.print_exc()
if transatcion_id:
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
raise AgentKitPluginError(f"MediaWiki API error: ${e.info}", e.code)
except EmbeddingRunningException:
error_msg = "Page index is running now"
if transatcion_id:
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
raise AgentKitPluginError(error_msg, "page-index-running")
except ConnectionResetError:
pass # Ignore websocket close error
except Exception as e:
error_msg = str(e)
print(error_msg, file=sys.stderr)
traceback.print_exc()
if transatcion_id:
await mwapi.ai_toolbox_cancel_transaction(transatcion_id, error_msg)
raise AgentKitPluginError(f"Error preparing page index: ${e}", "page-index-error")
async def on_before_completion(self, prompt: str) -> Optional[dict]:
if not self.embedding_search:
return prompt
try:
extracted_docs = await self.embedding_search.search(prompt, self.extract_limit, True, self.distance_limit)
generated_prompt = ""
for doc in extracted_docs:
generated_prompt += f"- [{doc['title']}]({doc['url']})\n"
return {
"document": generated_prompt
}
except EmbeddingRunningException:
return
except Exception as e:
error_msg = str(e)
print(error_msg, file=sys.stderr)
traceback.print_exc()
raise AgentKitPluginError(f"Error getting related content: ${e}", "related-content-error")

@ -0,0 +1,52 @@
from __future__ import annotations
from typing import Any, Callable, List, Optional, TypedDict
class ConfigSelectItem(TypedDict):
label: str
value: Any
class PluginConfigParam(TypedDict):
id: str
type: str
name: str
description: str
default: Any
required: bool
min: Optional[float]
max: Optional[float]
options: Optional[list[ConfigSelectItem]]
class HookInfo(TypedDict):
id: str
name: str
description: str
class_: Callable
class BaseMessage(TypedDict):
role: str
content: str
function_call: Optional[Any]
tool_calls: Optional[list]
tool_call_id: Optional[str]
class LLMFunctionInfo(TypedDict):
name_for_llm: str
name_for_human: str
description: str
params: dict
class LLMFunctionResponse(TypedDict):
response: str
"""The response text generated by the function."""
direct_return: bool
"""Directly return the response to the user, without further processing."""
OnProgressCallback = Callable[[int, int], Any]

@ -127,9 +127,8 @@ class PageIndexHelper:
self.table_initialized = True
async def create_embedding_index(self):
pass
# await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
# .replace("/*_*/", self.table_name))
await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
.replace("/*_*/", self.table_name))
def sha1_doc(self, doc: list):
for item in doc:

@ -409,7 +409,7 @@ class ChatCompleteService:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_model_id", "default", str)
openai_api = OpenAIApi.create(api_id)
@ -442,7 +442,7 @@ class ChatCompleteService:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
api_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_api_id", "default", str)
model_id = Config.get("chatcomplete.system_model_id", "default", str)
openai_api = OpenAIApi.create(api_id)

Loading…
Cancel
Save