from __future__ import annotations
import hashlib
from typing import Optional, Type

import asyncpg
from api.model.base import BaseModel
import config
import numpy as np
import sqlalchemy
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector

from service.database import DatabaseService

# Memo of generated ORM model classes, keyed by collection id; filled lazily
# by create_page_index_model so each collection gets a single model class.
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = dict()


class AbstractPageIndexModel(BaseModel):
    """Columns shared by every per-collection page-index table.

    ``__abstract__ = True`` tells SQLAlchemy not to map this class to a table
    of its own; concrete subclasses that set ``__tablename__`` are built by
    ``create_page_index_model``.
    """

    __abstract__ = True

    # Auto-incrementing surrogate primary key.
    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer, primary_key=True, autoincrement=True
    )
    # Page the chunk belongs to; indexed because queries filter on it.
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # 40-char hex SHA-1 identifying the chunk (presumably derived from its
    # text by the caller/helper); indexed for "already indexed?" checks.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    # pgvector column; its dimension comes from config.EMBEDDING_VECTOR_SIZE.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    # Plain text of the chunk and its length.
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    # Optional markdown rendering of the chunk and its length (both nullable).
    markdown: Mapped[Optional[str]] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
    # NULL for persistent rows; otherwise the id of the temporary-document
    # session that owns this temporary row.
    temp_doc_session_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)


def create_page_index_model(collection_id: int):
    """Return the ORM model class for one collection's page-index table.

    The class is generated on first request and memoized in
    ``page_index_model_list``; later calls with the same id return the
    cached class instead of declaring a new one.

    :param collection_id: collection whose table
        ``embedding_search_page_index_<collection_id>`` the model maps.
    :return: a concrete subclass of ``AbstractPageIndexModel``.
    """
    cached_model = page_index_model_list.get(collection_id)
    if cached_model is not None:
        return cached_model

    class _PageIndexModel(AbstractPageIndexModel):
        __tablename__ = "embedding_search_page_index_" + str(collection_id)

        # ivfflat index over the embedding column with cosine-distance ops.
        embedding_index = sqlalchemy.Index(
            __tablename__ + "_embedding_idx",
            AbstractPageIndexModel.embedding,
            postgresql_using="ivfflat",
            postgresql_ops={"embedding": "vector_cosine_ops"},
        )

    page_index_model_list[collection_id] = _PageIndexModel
    return _PageIndexModel


class PageIndexHelper:
    """Read/write helper for one collection's page-index (embedding) table.

    Use as an async context manager. Entering acquires a raw asyncpg
    connection from ``dbs.pool`` (with pgvector codecs registered) and opens
    an ORM session from ``dbs.create_session``; exiting releases both.
    Most queries run on the raw connection (``self.dbi``); a few use the ORM
    session (``self.session``).
    """

    # Column names of the page-index table.
    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id",
    ]

    def __init__(self, dbs: DatabaseService, collection_id: int):
        """Store configuration only; no database work happens until entered.

        :param dbs: service providing ``pool``, ``engine`` and ``create_session``.
        :param collection_id: selects table ``embedding_search_page_index_<id>``.
        """
        self.dbs = dbs
        self.collection_id = collection_id
        self.table_name = "embedding_search_page_index_%s" % str(collection_id)
        self.initialized = False
        self.table_initialized = False

    async def __aenter__(self):
        """Acquire a pooled connection and an ORM session; return ``self``.

        Entering an already-initialized helper is a no-op that still returns
        ``self``, so ``async with helper as h`` always binds ``h``.
        """
        if self.initialized:
            return self

        # Resolved first: it needs no connection, so a failure here leaks nothing.
        self.orm = create_page_index_model(self.collection_id)

        self.dbpool = self.dbs.pool.acquire()
        self.dbi = await self.dbpool.__aenter__()
        try:
            # Teach this asyncpg connection to encode/decode pgvector values.
            await register_vector(self.dbi)

            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
        except BaseException as err:
            # Return the connection to the pool if setup fails part-way.
            await self.dbpool.__aexit__(type(err), err, err.__traceback__)
            raise

        self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the pooled connection and close the ORM session.

        Both are released even if releasing the other raises. Afterwards the
        helper may be entered again. Exceptions are never suppressed.
        """
        if not self.initialized:
            return
        self.initialized = False
        try:
            await self.dbpool.__aexit__(exc_type, exc, tb)
        finally:
            await self.session.__aexit__(exc_type, exc, tb)

    async def table_exists(self):
        """Return True if ``self.table_name`` exists in the ``public`` schema."""
        exists = await self.dbi.fetchval(
            """SELECT EXISTS (
                SELECT 1
                FROM   information_schema.tables
                WHERE  table_schema = 'public'
                AND    table_name = $1
            );""",
            self.table_name,
            column=0,
        )

        return bool(exists)

    async def init_table(self):
        """Create the table from the ORM model's metadata if it is missing.

        The result is cached per helper instance, so repeated calls are cheap.
        """
        if self.table_initialized:
            return

        if not await self.table_exists():
            async with self.dbs.engine.begin() as conn:
                await conn.run_sync(self.orm.__table__.create)

        self.table_initialized = True

    async def create_embedding_index(self):
        """Create the ivfflat embedding index; currently intentionally a no-op.

        The previously used statement was:
        ``CREATE INDEX IF NOT EXISTS <table>_embedding_idx ON <table>
        USING ivfflat (embedding vector_cosine_ops)``.
        """
        pass

    def sha1_doc(self, doc: list):
        """Fill in ``item["sha1"]`` (hex SHA-1 of the UTF-8 text) where missing or empty.

        Mutates the items of ``doc`` in place; existing non-empty values are kept.
        """
        for item in doc:
            if not item.get("sha1"):
                item["sha1"] = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()

    @staticmethod
    def _markdown_fields(item: dict):
        """Return ``(markdown, markdown_len)`` for a doc item.

        The markdown columns are nullable, so a missing or None markdown maps
        to ``(None, None)`` instead of failing on ``len(None)``.
        """
        markdown = item.get("markdown")
        return markdown, (len(markdown) if markdown is not None else None)

    async def get_indexed_sha1(
        self, page_id: int, with_temporary: bool = True, in_collection: bool = False
    ):
        """Return the SHA-1 values of rows already stored in the table.

        :param page_id: restrict to this page unless ``in_collection`` is true.
        :param with_temporary: when false, skip rows owned by a temporary-doc
            session (``temp_doc_session_id`` not NULL).
        :param in_collection: when true, search the whole collection.
        :return: list of SHA-1 strings, one per matching row.
        """
        # Select just the sha1 column; loading full entities would also pull
        # every text and embedding over the wire.
        stmt = select(self.orm.sha1)

        if not with_temporary:
            stmt = stmt.where(self.orm.temp_doc_session_id.is_(None))

        if not in_collection:
            stmt = stmt.where(self.orm.page_id == page_id)

        result = await self.session.scalars(stmt)
        return list(result)

    async def get_unindexed_doc(
        self, doc: list, page_id: int, with_temporary: bool = True
    ):
        """Return the items of ``doc`` whose SHA-1 is not yet stored for ``page_id``.

        Fills in missing ``sha1`` values on ``doc`` as a side effect.
        """
        indexed = set(await self.get_indexed_sha1(page_id, with_temporary))
        self.sha1_doc(doc)
        return [item for item in doc if item["sha1"] not in indexed]

    async def remove_outdated_doc(self, doc: list, page_id: int):
        """Delete a page's rows that are no longer part of ``doc``.

        First clears all temporary rows of the page, then deletes persistent
        rows whose SHA-1 does not appear in ``doc``.
        """
        await self.clear_temp(page_id=page_id)

        indexed_sha1_list = await self.get_indexed_sha1(page_id, False)
        self.sha1_doc(doc)

        doc_sha1 = {item["sha1"] for item in doc}
        should_remove = [sha1 for sha1 in indexed_sha1_list if sha1 not in doc_sha1]

        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)"
                % (self.table_name),
                page_id,
                should_remove,
            )

    async def index_doc(self, doc: list, page_id: int):
        """Persist ``doc`` items for ``page_id``.

        Items already stored as temporary rows are promoted to persistent
        (``temp_doc_session_id`` set to NULL). Items not stored at all are
        inserted. Each item needs ``text``, ``embedding`` and optionally
        ``markdown``/``sha1``. Rows absent from ``doc`` are NOT deleted here
        (see ``remove_outdated_doc``).
        """
        ret = await self.dbi.fetch(
            "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1"
            % (self.table_name),
            page_id,
        )
        indexed_persist_sha1 = set()
        indexed_temp_sha1 = set()
        for row in ret:
            if row[1] is not None:
                indexed_temp_sha1.add(row[0])
            else:
                indexed_persist_sha1.add(row[0])

        # No persistent rows yet for this page: request index creation.
        need_create_index = not indexed_persist_sha1

        self.sha1_doc(doc)

        should_index = []
        should_persist = []
        for item in doc:
            sha1 = item["sha1"]
            if sha1 in indexed_temp_sha1:
                # Already embedded as a temporary row; just promote it.
                should_persist.append(sha1)
            elif sha1 not in indexed_persist_sha1:
                should_index.append(item)

        if should_index:
            await self.dbi.executemany(
                """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);"""
                % (self.table_name),
                [
                    (
                        item["sha1"],
                        page_id,
                        item["text"],
                        len(item["text"]),
                        *self._markdown_fields(item),
                        item["embedding"],
                    )
                    for item in should_index
                ],
            )

        if should_persist:
            await self.dbi.executemany(
                "UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2"
                % (self.table_name),
                [(page_id, sha1) for sha1 in should_persist],
            )

        if need_create_index:
            await self.create_embedding_index()

    async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
        """Sync the temporary rows of ``temp_doc_session_id`` with ``doc``.

        Inserts items whose SHA-1 is neither persistent nor already in this
        session (each item must carry its own ``page_id``). Deletes this
        session's temporary rows whose SHA-1 no longer appears in ``doc``.
        """
        sql = (
            "SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)"
            % (self.table_name)
        )
        ret = await self.dbi.fetch(sql, temp_doc_session_id)
        indexed_sha1 = set()
        indexed_temp_sha1 = set()
        for row in ret:
            indexed_sha1.add(row[0])
            if row[1] is not None:
                indexed_temp_sha1.add(row[0])

        self.sha1_doc(doc)

        doc_sha1 = {item["sha1"] for item in doc}
        should_index = [item for item in doc if item["sha1"] not in indexed_sha1]
        should_remove = [sha1 for sha1 in indexed_temp_sha1 if sha1 not in doc_sha1]

        if should_index:
            await self.dbi.executemany(
                """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8);"""
                % (self.table_name),
                [
                    (
                        item["sha1"],
                        item["page_id"],
                        item["text"],
                        len(item["text"]),
                        *self._markdown_fields(item),
                        item["embedding"],
                        temp_doc_session_id,
                    )
                    for item in should_index
                ],
            )

        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)"
                % (self.table_name),
                temp_doc_session_id,
                should_remove,
            )

    async def search_text_embedding(
        self,
        embedding: np.ndarray,
        in_collection: bool = False,
        limit: int = 10,
        page_id: Optional[int] = None,
    ):
        """Return the ``limit`` rows nearest to ``embedding``, closest first.

        :param in_collection: search the whole collection; otherwise only
            rows of ``page_id`` (a None ``page_id`` then matches no rows).
        :return: asyncpg records with id, sha1, text, text_len, markdown,
            markdown_len and ``distance``.
        """
        # NOTE(review): ``<->`` is pgvector's L2 distance, while the declared
        # ivfflat index uses vector_cosine_ops (matched by ``<=>``), so that
        # index cannot serve this query. Switching changes the returned
        # distance values; confirm no caller thresholds on them first.
        if in_collection:
            return await self.dbi.fetch(
                """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
                FROM %s
                ORDER BY distance ASC
                LIMIT %d"""
                % (self.table_name, limit),
                embedding,
            )
        return await self.dbi.fetch(
            """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
            FROM %s
            WHERE page_id = $2
            ORDER BY distance ASC
            LIMIT %d"""
            % (self.table_name, limit),
            embedding,
            page_id,
        )

    def _clear_temp_query(
        self,
        in_collection: bool,
        temp_doc_session_id: Optional[int],
        page_id: Optional[int],
    ):
        """Build the ``(sql, params)`` pair executed by ``clear_temp``."""
        where = []
        params = []

        if not in_collection and page_id is not None:
            params.append(page_id)
            where.append("page_id = $%d" % len(params))

        if temp_doc_session_id is not None:
            params.append(temp_doc_session_id)
            where.append("temp_doc_session_id = $%d" % len(params))
        else:
            # Never touch persistent rows (they have a NULL session id).
            where.append("temp_doc_session_id IS NOT NULL")

        sql = "DELETE FROM %s WHERE %s" % (self.table_name, " AND ".join(where))
        return sql, params

    async def clear_temp(
        self,
        in_collection: bool = False,
        temp_doc_session_id: Optional[int] = None,
        page_id: Optional[int] = None,
    ):
        """Delete temporary rows.

        :param in_collection: ignore ``page_id`` and act on the whole table.
        :param temp_doc_session_id: only delete rows of this session; when
            None, delete every temporary row in scope.
        :param page_id: limit deletion to this page (unless ``in_collection``).
        """
        sql, params = self._clear_temp_query(in_collection, temp_doc_session_id, page_id)
        await self.dbi.execute(sql, *params)

    async def remove_by_page_id(self, page_id: int):
        """Delete every row (temporary and persistent) of ``page_id`` and commit."""
        stmt = delete(self.orm).where(self.orm.page_id == page_id)
        await self.session.execute(stmt)
        await self.session.commit()