from __future__ import annotations import time import sqlalchemy from sqlalchemy import select, update from sqlalchemy.orm import mapped_column, relationship, Mapped from api.model.base import BaseHelper, BaseModel from api.model.toolkit_ui.conversation import ConversationModel from service.database import DatabaseService from service.event import EventService class ConversationChunkModel(BaseModel): __tablename__ = "chat_complete_conversation_chunk" id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) conversation_id: Mapped[int] = mapped_column( sqlalchemy.ForeignKey(ConversationModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True) message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0) updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) class ConversationChunkHelper(BaseHelper): async def add(self, obj: ConversationChunkModel): obj.updated_at = int(time.time()) self.session.add(obj) await self.session.commit() await self.session.refresh(obj) return obj async def update(self, obj: ConversationChunkModel): obj.updated_at = int(time.time()) obj = await self.session.merge(obj) await self.session.commit() return obj async def update_message_log(self, chunk_id: int, message_data: list, tokens: int): stmt = update(ConversationChunkModel).where(ConversationChunkModel.id == chunk_id) \ .values(message_data=message_data, tokens=tokens, updated_at=int(time.time())) await self.session.execute(stmt) await self.session.commit() async def get_newest_chunk(self, conversation_id: int): stmt = select(ConversationChunkModel) \ .where(ConversationChunkModel.conversation_id == conversation_id) \ .order_by(ConversationChunkModel.id.desc()) \ .limit(1) return await self.session.scalar(stmt) async def get_chunk_id_list(self, conversation_id: int): stmt = select(ConversationChunkModel.id) \ .where(ConversationChunkModel.conversation_id == conversation_id).order_by(ConversationChunkModel.id.asc()) return await self.session.scalars(stmt) async def find_by_id(self, id: int): stmt = select(ConversationChunkModel).where(ConversationChunkModel.id == id) return await self.session.scalar(stmt) async def remove(self, id: int | list[int]): if isinstance(id, list): stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id.in_(id)) else: stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.id == id) await self.session.execute(stmt) await self.session.commit() async def remove_by_conversation_ids(self, ids: list[int]): stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id.in_(ids)) await self.session.execute(stmt) await self.session.commit() async def on_conversation_removed(event): if "ids" in event: conversation_ids = event["ids"] async with ConversationChunkHelper(event["dbs"]) as chunk_helper: await chunk_helper.remove_by_conversation_ids(conversation_ids) EventService.create().add_listener("conversation/removed", on_conversation_removed)