You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.2 KiB
Python
90 lines
2.2 KiB
Python
from __future__ import annotations
|
|
from typing import Optional
|
|
from agentkit.types import BaseMessage
|
|
from agentkit.utils.progress import AgentKitTaskProgress
|
|
|
|
class ConversationData:
    """Persistent state of a single conversation.

    Only the attribute names listed in ``KEYS_IN_STORE`` are exported by
    :meth:`to_dict`; anything else set on the instance is ignored there.
    """

    #: Attribute names exported by ``to_dict`` (in this order).
    KEYS_IN_STORE = ["history", "store", "user_id", "conversation_id", "conversation_rounds"]

    def __init__(self):
        # --- Persistent properties ---

        #: Chat history
        self.history: list[BaseMessage] = []

        #: Store for plugins
        self.store: dict = {}

        #: Current user ID; None until assigned
        self.user_id: Optional[str] = None

        #: Current conversation ID; None until assigned
        self.conversation_id: Optional[str] = None

        #: Number of conversation rounds; starts at 0
        self.conversation_rounds: int = 0

    def to_dict(self):
        """Return a dict mapping every name in ``KEYS_IN_STORE`` to its current value.

        Keys appear in ``KEYS_IN_STORE`` order. Values are NOT copied: the
        ``history`` list and ``store`` dict in the result are the very objects
        held by this instance, so mutating them mutates the conversation.

        Raises:
            AttributeError: if a name in ``KEYS_IN_STORE`` is not set on the instance.
        """
        exported = {}
        for attr_name in self.KEYS_IN_STORE:
            exported[attr_name] = getattr(self, attr_name)
        return exported
|
|
|
|
|
|
class ConversationContext(ConversationData):
    """Conversation state plus temporary properties for the current call.

    The temporary properties are not listed in ``KEYS_IN_STORE``, so
    ``to_dict()`` does not include them. The mapping protocol
    (``ctx[key]``, ``ctx[key] = v``, ``del ctx[key]``, ``key in ctx``,
    ``iter(ctx)``, ``len(ctx)``) is forwarded to ``session_data``.
    """

    def __init__(self):
        super().__init__()

        # --- Temporary properties ---

        #: User prompt
        self.prompt = ""

        #: System prompt
        self.system_prompt = ""

        #: Temporary data for plugins
        self.session_data: dict = {}

        #: Completion result; None until one is set
        self.completion: Optional[str] = None

        #: Task progress
        self.progress = AgentKitTaskProgress()

    def get_messages_by_role(self, role: str) -> list[BaseMessage]:
        """Return the messages in ``history`` whose ``"role"`` entry equals *role*.

        Messages keep their order in ``history``.

        Raises:
            KeyError: if any message lacks a ``"role"`` key.
        """
        return list(filter(lambda message: message["role"] == role, self.history))

    def create_sub_context(self, copy_store=False, copy_session_data=False) -> ConversationContext:
        """Build a fresh ``ConversationContext`` that shares this context's ``user_id``.

        Args:
            copy_store: when true, the child gets ``self.store.copy()``.
            copy_session_data: when true, the child gets ``self.session_data.copy()``.

        Returns:
            A new context. Everything else (history, conversation_id,
            conversation_rounds, prompts, completion, progress) comes from a
            freshly constructed ``ConversationContext()``.

        Note:
            Both copies are shallow. Nested mutable values (e.g. a per-plugin
            dict inside ``store``) remain shared with this context.
        """
        child = ConversationContext()
        child.user_id = self.user_id

        # Store is handled before session data, matching the flag order.
        optional_copies = (
            (copy_store, "store"),
            (copy_session_data, "session_data"),
        )
        for requested, attr_name in optional_copies:
            if requested:
                setattr(child, attr_name, getattr(self, attr_name).copy())

        return child

    def __getitem__(self, key):
        """Return ``session_data[key]``; raises KeyError when *key* is absent."""
        return self.session_data[key]

    def __setitem__(self, key, value):
        """Store *value* under *key* in ``session_data`` and hand *value* back."""
        self.session_data[key] = value
        return value

    def __delitem__(self, key):
        """Remove *key* from ``session_data``; raises KeyError when absent."""
        del self.session_data[key]

    def __contains__(self, key):
        """Report whether *key* is present in ``session_data``."""
        return key in self.session_data

    def __iter__(self):
        """Iterate over the keys of ``session_data``."""
        return iter(self.session_data)

    def __len__(self):
        """Return the number of entries in ``session_data``.

        NOTE: because ``__len__`` is defined (and ``__bool__`` is not), a
        context whose ``session_data`` is empty is falsy. Test for presence
        with ``ctx is not None``, not ``if ctx:``.
        """
        return len(self.session_data)