from __future__ import annotations
import time

import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped

from api.model.base import BaseHelper, BaseModel
from api.model.toolkit_ui.conversation import ConversationModel
from service.database import DatabaseService
from service.event import EventService

class ConversationChunkModel(BaseModel):
    """ORM mapping for one stored chunk of a chat conversation's message log.

    A conversation may own several chunk rows, linked through
    ``conversation_id``. The foreign key is declared with
    ``ON DELETE CASCADE`` / ``ON UPDATE CASCADE``.
    """
    __tablename__ = "chat_complete_conversation_chunk"

    # Surrogate primary key. ConversationChunkHelper orders by it, so a higher id means a newer chunk.
    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Owning conversation (indexed). Declared to cascade on delete/update of the parent row.
    conversation_id: Mapped[int] = mapped_column(
        sqlalchemy.ForeignKey(ConversationModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
    # Message log stored as a JSON column; nullable.
    # NOTE(review): element schema is defined by callers, not here — confirm.
    message_data: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True)
    # Token count stored alongside message_data; defaults to 0 on insert.
    tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, default=0)
    # Last-modified time in Unix epoch seconds (the helpers write int(time.time())); indexed.
    updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)

class ConversationChunkHelper(BaseHelper):
    """Async data-access helper for ``ConversationChunkModel`` rows.

    Every mutating method commits ``self.session`` before returning.
    """

    async def add(self, obj: ConversationChunkModel) -> ConversationChunkModel:
        """Insert *obj*, stamping ``updated_at`` with the current time.

        Returns the same instance, refreshed from the database after the
        commit so that generated values such as ``id`` are loaded.
        """
        obj.updated_at = int(time.time())
        self.session.add(obj)
        await self.session.commit()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: ConversationChunkModel) -> ConversationChunkModel:
        """Merge *obj* into the session and persist it, stamping ``updated_at``.

        Returns the session-bound (merged) instance, refreshed after the
        commit exactly like :meth:`add`.
        """
        obj.updated_at = int(time.time())
        merged = await self.session.merge(obj)
        await self.session.commit()
        # commit() may expire loaded attributes (expire_on_commit). Reload them
        # explicitly so callers can read fields without an implicit lazy load,
        # which is not allowed on an async session.
        await self.session.refresh(merged)
        return merged

    async def update_message_log(self, chunk_id: int, message_data: list, tokens: int) -> None:
        """Overwrite ``message_data`` and ``tokens`` of chunk *chunk_id*.

        Runs a bulk UPDATE without loading an ORM instance, and also bumps
        ``updated_at``. An unknown *chunk_id* matches no rows and raises nothing.
        """
        # ``update`` here is sqlalchemy.update (module global), not the method above.
        stmt = (
            update(ConversationChunkModel)
            .where(ConversationChunkModel.id == chunk_id)
            .values(message_data=message_data, tokens=tokens, updated_at=int(time.time()))
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def get_newest_chunk(self, conversation_id: int) -> ConversationChunkModel | None:
        """Return the chunk with the highest id for *conversation_id*, or ``None``."""
        stmt = (
            select(ConversationChunkModel)
            .where(ConversationChunkModel.conversation_id == conversation_id)
            .order_by(ConversationChunkModel.id.desc())
            .limit(1)
        )
        return await self.session.scalar(stmt)

    async def get_chunk_id_list(self, conversation_id: int):
        """Return the chunk ids of *conversation_id* in ascending id order.

        The value is the scalar result object returned by
        ``session.scalars()``. Iterate it, or call ``.all()`` for a list.
        """
        stmt = (
            select(ConversationChunkModel.id)
            .where(ConversationChunkModel.conversation_id == conversation_id)
            .order_by(ConversationChunkModel.id.asc())
        )
        return await self.session.scalars(stmt)

    async def find_by_id(self, id: int) -> ConversationChunkModel | None:
        """Return the chunk whose primary key is *id*, or ``None`` if absent."""
        # ``id`` shadows the builtin, but it is part of the public signature
        # (callers may pass it by keyword), so it is kept.
        stmt = select(ConversationChunkModel).where(ConversationChunkModel.id == id)
        return await self.session.scalar(stmt)

    async def remove(self, id: int | list[int] | tuple[int, ...] | set[int] | frozenset[int]) -> None:
        """Delete one chunk (a single id) or several chunks (a collection of ids).

        Lists, tuples, sets and frozensets are all treated as collections.
        Previously only ``list`` was, so a tuple or set fell into the ``==``
        branch and produced an invalid comparison.
        """
        if isinstance(id, (list, tuple, set, frozenset)):
            # Normalise to a list so IN receives a plain, ordered sequence.
            condition = ConversationChunkModel.id.in_(list(id))
        else:
            condition = ConversationChunkModel.id == id
        stmt = sqlalchemy.delete(ConversationChunkModel).where(condition)
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove_by_conversation_ids(self, ids: list[int]) -> None:
        """Delete every chunk whose ``conversation_id`` is in *ids*.

        The deletion is explicit rather than left to the declared FK cascade.
        NOTE(review): presumably so cleanup also works where FK enforcement is off — confirm.
        """
        stmt = sqlalchemy.delete(ConversationChunkModel).where(ConversationChunkModel.conversation_id.in_(ids))
        await self.session.execute(stmt)
        await self.session.commit()
    
async def on_conversation_removed(event):
    """Listener for "conversation/removed": delete the chunks of the removed conversations.

    *event* is indexed like a mapping: ``event["ids"]`` holds the removed
    conversation ids and ``event["dbs"]`` is passed to ``ConversationChunkHelper``.
    NOTE(review): ``"dbs"`` is presumably the database service — confirm against the emitter.

    Events without ``"ids"``, or whose ids are empty or ``None``, are ignored
    without opening a database session. ``None`` would otherwise reach
    ``.in_()`` and raise.
    """
    if "ids" not in event:
        return
    conversation_ids = event["ids"]
    if not conversation_ids:
        # Nothing to delete: skip the DB session and the no-op DELETE/commit.
        return
    async with ConversationChunkHelper(event["dbs"]) as chunk_helper:
        await chunk_helper.remove_by_conversation_ids(conversation_ids)

# Module-level side effect: importing this module registers on_conversation_removed
# for the "conversation/removed" event on the object returned by EventService.create().
# NOTE(review): presumably create() returns a shared service instance, so the listener
# is seen by whoever dispatches the event — confirm.
EventService.create().add_listener("conversation/removed", on_conversation_removed)