"""ORM model and async helper for chat-complete conversation chunks.

Also registers a listener that deletes a conversation's chunks when the
conversation is removed.
"""
from __future__ import annotations

import sqlalchemy
from sqlalchemy import update
from sqlalchemy.orm import mapped_column, relationship, Mapped

from api.model.base import BaseModel
from api.model.toolkit_ui.conversation import ConversationModel
from service.database import DatabaseService
from service.event import EventService


class ConversationChunkModel(BaseModel):
    """One stored chunk of a chat-complete conversation's message log."""

    __tablename__ = "chat_complete_conversation_chunk"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Owning conversation; indexed because chunks are queried and deleted per conversation.
    conversation_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(ConversationModel.id), index=True)
    # Message list for this chunk, stored as JSON (nullable).
    message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
    # Presumably the token count of message_data; defaults to 0 on insert.
    tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
    # NOTE(review): the column type is TIMESTAMP but the annotation says int;
    # Mapped[datetime] is presumably intended — confirm before changing.
    updated_at: Mapped[int] = mapped_column(sqlalchemy.TIMESTAMP, index=True)


class ConversationChunkHelper:
    """Async data-access helper for ConversationChunkModel rows.

    Use it as an async context manager. Entering it opens the session
    (``self.session``) that the query methods run on. It also exposes
    ``self.create_session``, which ``add`` uses.
    """

    def __init__(self, dbs: DatabaseService):
        self.dbs = dbs
        self.initialized = False

    async def __aenter__(self) -> ConversationChunkHelper:
        """Open the helper's session if one is not already open; return self."""
        if not self.initialized:
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
            self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Exit the session's context, forwarding any in-flight exception."""
        await self.session.__aexit__(exc_type, exc, tb)
        # Reset so a later `async with` on this helper opens a fresh session
        # instead of reusing the one whose context was just exited.
        self.initialized = False

    async def add(self, conversation_id: int, message_data: list, tokens: int) -> ConversationChunkModel:
        """Insert a new chunk in its own short-lived session and return it.

        ``updated_at`` is set to the database's CURRENT_TIMESTAMP. The row is
        refreshed after commit, so server-generated fields such as ``id`` are
        populated on the returned object.
        """
        async with self.create_session() as session:
            chunk = ConversationChunkModel(
                conversation_id=conversation_id,
                message_data=message_data,
                tokens=tokens,
                updated_at=sqlalchemy.func.current_timestamp(),
            )
            session.add(chunk)
            await session.commit()
            await session.refresh(chunk)
            return chunk

    async def update(self, chunk: ConversationChunkModel) -> ConversationChunkModel:
        """Persist changes on ``chunk``, bump ``updated_at``, and return the merged instance."""
        chunk.updated_at = sqlalchemy.func.current_timestamp()
        # merge() copies the (possibly detached) object's state into this session.
        chunk = await self.session.merge(chunk)
        await self.session.commit()
        return chunk

    async def update_message_log(self, chunk_id: int, message_data: list, tokens: int):
        """Overwrite a chunk's message_data and tokens and bump updated_at, by primary key."""
        stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \
            .values(message_data=message_data, tokens=tokens, updated_at=sqlalchemy.func.current_timestamp())
        await self.session.execute(stmt)
        await self.session.commit()

    async def get_newest_chunk(self, conversation_id: int) -> ConversationChunkModel | None:
        """Return the conversation's chunk with the highest id, or None if it has none."""
        stmt = sqlalchemy.select(ConversationChunkModel) \
            .where(ConversationChunkModel.conversation_id == conversation_id) \
            .order_by(ConversationChunkModel.id.desc()) \
            .limit(1)
        return await self.session.scalar(stmt)

    async def remove(self, id: int):
        """Delete the chunk with the given primary key."""
        stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id)
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove_by_conversation_id(self, conversation_id: int):
        """Delete every chunk that belongs to ``conversation_id``."""
        stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id == conversation_id)
        await self.session.execute(stmt)
        await self.session.commit()


async def on_conversation_removed(event):
    """Delete all chunks of a conversation that was just removed.

    Reads ``event["conversation"]["id"]`` and ``event["dbs"]``. Events
    without a ``"conversation"`` key are ignored.
    """
    if "conversation" not in event:
        return
    conversation_id = event["conversation"]["id"]
    # The helper's query methods use the session opened in __aenter__, so it
    # must be entered; calling it bare raised AttributeError on self.session.
    async with ConversationChunkHelper(event["dbs"]) as chunk_helper:
        await chunk_helper.remove_by_conversation_id(conversation_id)


EventService.create().add_listener("conversation/removed/chatcomplete", on_conversation_removed)