from __future__ import annotations
import hashlib
import asyncpg
import numpy as np
from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector
import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine

from config import Config
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService

# Dimensionality of the pgvector `embedding` column below, read from the
# "chatcomplete.embedding_vector_size" config key (512 when unset).
# NOTE(review): the third argument to Config.get is presumably a type converter — confirm.
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)

class TitleIndexModel(BaseModel):
    """ORM mapping for a row of the ``embedding_search_title_index`` table.

    Each row links a page title (plus a SHA-1 hex digest of it) to a page and
    a collection, tracks which revision was embedded versus the latest known
    revision, and stores an optional pgvector embedding of the title.
    """
    __tablename__ = "embedding_search_title_index"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Hex SHA-1 of `title`; the helper code looks rows up by this value.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # Revision the stored embedding was computed from. Nullable — presumably NULL
    # until the title is first embedded (verify against the indexing job).
    # NOTE(review): annotation says Mapped[int] although the column is nullable.
    indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
    # Latest known revision; when it differs from indexed_rev_id the row needs re-indexing.
    latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
    # Title embedding with `embedding_vector_size` dimensions. deferred() keeps the
    # vector out of default ORM SELECTs so ordinary row loads stay small.
    embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(embedding_vector_size), nullable=True))

    # Approximate-nearest-neighbour index using the cosine-distance operator class
    # (per pgvector docs, only queries ordered by `embedding <=> ...` can use it).
    # NOTE(review): this Index is a plain class attribute, not part of __table_args__;
    # it relies on SQLAlchemy attaching it through the column — confirm it is created.
    embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
                                       postgresql_using='ivfflat',
                                       postgresql_ops={'embedding': 'vector_cosine_ops'})

class TitleIndexHelper(BaseHelper):
    """Data-access helper for the ``embedding_search_title_index`` table.

    ORM operations go through ``self.session`` (provided by ``BaseHelper``).
    Vector similarity search goes through ``self.dbi``, a raw asyncpg
    connection acquired from ``self.dbs.pool`` with the pgvector codec
    registered on it.
    """

    __tablename__ = "embedding_search_title_index"

    # Column names used to build raw SQL select lists in get_columns().
    columns = [
        "id",
        "sha1",
        "title",
        "page_id",
        "collection_id",
        "indexed_rev_id",
        "latest_rev_id",
        "embedding",
    ]

    def __init__(self, dbs: DatabaseService):
        super().__init__(dbs)

    async def __aenter__(self):
        """Acquire a pooled asyncpg connection (if not yet initialized) and enter the base helper.

        If anything fails after the connection was acquired here, the
        connection is returned to the pool before the error propagates,
        because ``__aexit__`` is never called when ``__aenter__`` raises.
        """
        acquired = False
        if not self.initialized:
            self.dbpool = self.dbs.pool.acquire()
            self.dbi = await self.dbpool.__aenter__()
            acquired = True

        try:
            if acquired:
                # Lets asyncpg encode/decode pgvector `vector` values on this connection.
                await register_vector(self.dbi)
            return await super().__aenter__()
        except BaseException as err:
            if acquired:
                await self.dbpool.__aexit__(type(err), err, err.__traceback__)
            raise

    async def __aexit__(self, exc_type, exc, tb):
        """Release the pooled connection, then always run the base helper's exit logic."""
        try:
            await self.dbpool.__aexit__(exc_type, exc, tb)
        finally:
            result = await super().__aexit__(exc_type, exc, tb)
        return result

    def get_columns(self, exclude=()):
        """Return ``self.columns`` as a comma-separated SQL select list.

        :param exclude: column names to leave out (any container supporting ``in``).
        """
        return ", ".join(col for col in self.columns if col not in exclude)

    async def add(self, obj: TitleIndexModel):
        """Add a title to the index.

        Sets ``obj.sha1`` from ``obj.title`` and inserts/commits the row unless
        a row with the same sha1 already exists.

        :return: the refreshed ``obj`` when inserted, ``False`` if already indexed.
        """
        obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()

        if await self.find_by_sha1(obj.sha1) is None:
            self.session.add(obj)
            await self.session.commit()
            await self.session.refresh(obj)
            return obj

        return False

    async def remove(self, title: str):
        """Remove a title from the index (matched by the SHA-1 of ``title``).

        NOTE(review): unlike add()/update() this does not commit; the caller
        (or BaseHelper) is presumably responsible for committing — confirm.
        """
        title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
        stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1)
        await self.session.execute(stmt)

    async def update(self, obj: TitleIndexModel):
        """Recompute ``obj.sha1``, merge ``obj`` into the session and commit.

        :return: the session-bound instance produced by ``session.merge``. This is
            ``obj`` itself when ``obj`` was already attached to this session, and
            a persistent copy when ``obj`` was detached.
        """
        obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
        # merge() returns the persistent instance; refreshing a detached `obj`
        # would raise because it is not part of this session.
        merged = await self.session.merge(obj)
        await self.session.commit()
        await self.session.refresh(merged)
        return merged

    async def refresh(self, obj: TitleIndexModel):
        """Reload ``obj``'s attributes from the database and return it."""
        await self.session.refresh(obj)
        return obj

    async def search_title_embedding(self, embedding: np.ndarray, limit: int = 10):
        """Search for the titles most similar to ``embedding`` by cosine distance.

        Rows without an embedding are skipped.

        :param embedding: query vector; it must have the same dimensionality as the column.
        :param limit: maximum number of rows to return.
        :return: asyncpg records holding every column except ``embedding``, plus a
            ``distance`` field (pgvector cosine distance, smaller = more similar),
            nearest first.
        """
        # `<=>` is pgvector's cosine-distance operator, matching the vector_cosine_ops
        # index. The column list comes from a class constant, never from user input.
        sql = (
            f"SELECT {self.get_columns(exclude=['embedding'])}, embedding <=> $1 AS distance"
            " FROM embedding_search_title_index"
            " WHERE embedding IS NOT NULL"
            " ORDER BY embedding <=> $1 ASC"
            " LIMIT $2"
        )
        return await self.dbi.fetch(sql, embedding, limit)

    async def find_by_title(self, title: str):
        """Find an indexed title; returns the row or ``None``."""
        sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
        return await self.find_by_sha1(sha1)

    async def find_by_sha1(self, sha1: str):
        """Return the first row whose ``sha1`` matches, or ``None``."""
        stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1)
        return await self.session.scalar(stmt)

    async def find_by_page_id(self, page_id: int):
        """Return the first row for ``page_id``, or ``None``."""
        stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
        return await self.session.scalar(stmt)

    async def find_list_by_collection_id(self, collection_id: int) -> sqlalchemy.ScalarResult[TitleIndexModel]:
        """Return all rows belonging to ``collection_id``."""
        stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
        return await self.session.scalars(stmt)

    async def get_need_update_index_list(self, collection_id: int) -> list[TitleIndexModel]:
        """Return rows in ``collection_id`` whose indexed revision is not the latest.

        A NULL ``indexed_rev_id`` compares unequal in Python, so rows that were
        never indexed are included.
        """
        page_list = await self.find_list_by_collection_id(collection_id)
        return [page for page in page_list if page.indexed_rev_id != page.latest_rev_id]