from __future__ import annotations
from typing import Optional, TypedDict

from libs.config import Config
from server.model.embedding_search.title_collection import (
    TitleCollectionHelper,
    TitleCollectionModel,
)
from server.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
from server.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.text_embedding import TextEmbeddingService
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences


class EmbeddingRunningException(Exception):
    pass


class EmbeddingSearchArgs(TypedDict):
    limit: Optional[int]
    in_collection: Optional[bool]
    distance_limit: Optional[float]


class EmbeddingSearchService:
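    """Manage embedding-based indexing and search for a single wiki page.

    Usage sketch (assumes an already-initialized DatabaseService instance ``dbs``):

        async with EmbeddingSearchService(dbs, "Some page title") as service:
            if await service.prepare_update_index():
                await service.update_page_index()
            results, token_usage = await service.search("query text")
    """

    # Class-level registry of page IDs whose indexing is currently in progress (shared across instances)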
    indexing_page_ids: list[int] = []

    def __init__(self, dbs: DatabaseService, title: str):
        self.dbs = dbs

        self.title = title
        self.base_title = title.split("/")[0]

        self.title_index_helper = TitleIndexHelper(dbs)
        self.title_collection_helper = TitleCollectionHelper(dbs)
        self.page_index: Optional[PageIndexHelper] = None

        self.text_embedding: Optional[TextEmbeddingService] = None
        self.tiktoken: Optional[TikTokenService] = None

        self.mwapi = MediaWikiApi.create()

        self.page_id: Optional[int] = None
        self.collection_id: Optional[int] = None

        self.title_index: Optional[TitleIndexModel] = None
        self.collection_info: Optional[TitleCollectionModel] = None

        self.page_info: Optional[dict] = None
        self.unindexed_docs: Optional[list] = None
        self.sentences: Optional[list] = None

    async def __aenter__(self):
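        """Initialize dependent services, enter helper contexts, and resolve the title index (and page index when a collection is already known)."""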
        self.text_embedding = await TextEmbeddingService.create()
        self.tiktoken = await TikTokenService.create()

        await self.title_index_helper.__aenter__()
        await self.title_collection_helper.__aenter__()

        self.title_index = await self.title_index_helper.find_by_title(self.title)
        if self.title_index is None:
            # The title may have changed; fetch page info and look up the index by page ID instead
            await self.load_page_info()
            self.title_index = await self.title_index_helper.find_by_page_id(
                self.page_info["pageid"]
            )
            self.page_id = self.page_info["pageid"]
        else:
            self.page_id = self.title_index.page_id
            self.collection_id = self.title_index.collection_id
            self.page_index = PageIndexHelper(self.dbs, self.collection_id)
            await self.page_index.__aenter__()

        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.title_index_helper.__aexit__(exc_type, exc, tb)
        await self.title_collection_helper.__aexit__(exc_type, exc, tb)

        if self.page_index is not None:
            await self.page_index.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self, check_table=True):
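        """Return True if the page index helper is available (and, when check_table is True, its table exists)."""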
        if check_table:
            return self.page_index is not None and await self.page_index.table_exists()
        else:
            return self.page_index is not None

    async def load_page_info(self, reload=False):
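        """Fetch and cache page info from the MediaWiki API; pass reload=True to force a refresh."""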
        if self.page_info is None or reload:
            self.page_info = await self.mwapi.get_page_info(self.title)

    async def should_update_page_index(self, remote_update=False):
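        """Return True if the page index is outdated; with remote_update, compare against the live revision on the wiki."""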
        if not remote_update:
            if self.title_index is None:
                return True
            return self.title_index.indexed_rev_id != self.title_index.latest_rev_id
        else:
            await self.load_page_info()

            if (
                self.title_index is not None
                and await self.page_index_exists()
                and self.title_index.indexed_rev_id == self.page_info["lastrevid"]
            ):
                # Not changed
                return False

            return True

    async def update_title_index(self, remote_update=False):
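        """Create or update the title index and collection records, then (re)open the page index table; returns False if nothing needs updating."""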
        if not await self.should_update_page_index(remote_update):
            return False

        await self.load_page_info()

        self.page_id = self.page_info["pageid"]

        if self.page_id in self.indexing_page_ids:
            raise EmbeddingRunningException("Page indexing is already running for this page")

        self.title: str = self.page_info["title"]
        self.base_title = self.title.split("/")[0]

        # Find collection by base title
        self.collection_info = await self.title_collection_helper.find_by_title(
            self.base_title
        )
        if self.collection_info is None:
            # Create collection
            self.collection_info = await self.title_collection_helper.add(
                self.base_title
            )
            if self.collection_info is None:
                raise Exception("Failed to create title collection")
        self.collection_id = self.collection_info.id

        if self.title_index is None:
            # Create title record
            self.title_index = TitleIndexModel(
                title=self.page_info["title"],
                page_id=self.page_id,
                indexed_rev_id=None,
                latest_rev_id=self.page_info["lastrevid"],
                collection_id=self.collection_id,
                embedding=None,
            )
            self.title_index = await self.title_index_helper.add(self.title_index)
            if self.title_index is None:
                raise Exception("Failed to create title index")
        else:
            self.title_index.latest_rev_id = self.page_info["lastrevid"]
            # Title changed, remove embedding
            # Title sha1 will be updated by model helper
            if self.title_index.title != self.page_info["title"]:
                self.title_index.title = self.page_info["title"]
                self.title_index.embedding = None

            # Collection changed, remove old index
            if self.collection_id != self.title_index.collection_id:
                async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index:
                    await old_page_index.init_table()
                    await old_page_index.remove_by_page_id(self.page_id)
                self.title_index.collection_id = self.collection_id

            await self.title_index_helper.update(self.title_index)

        # Update collection main page id
        if (
            self.title == self.collection_info.title
            and self.page_id != self.collection_info.page_id
        ):
            await self.title_collection_helper.set_main_page_id(
                self.base_title, self.page_id
            )

        if self.page_index:
            await self.page_index.__aexit__(None, None, None)
        self.page_index = PageIndexHelper(self.dbs, self.collection_id)
        await self.page_index.__aenter__()
        await self.page_index.init_table()

        return True

    async def prepare_update_index(self):
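        """Update the title index, parse the page, split it into sentences, and collect the documents that still need embedding."""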
        await self.update_title_index()

        page_content = await self.mwapi.parse_page(self.title)

        self.sentences = getWikiSentences(page_content)

        self.unindexed_docs = await self.page_index.get_unindexed_doc(
            self.sentences, self.page_id, with_temporary=False
        )

        return True

    async def get_unindexed_tokens(self):
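        """Return the total token count of the documents that still need to be embedded (0 if none were prepared)."""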
        if self.unindexed_docs is None:
            return 0
        else:
            tokens = 0
            for doc in self.unindexed_docs:
                if "text" in doc:
                    tokens += await self.tiktoken.get_tokens(doc["text"])

            return tokens

    async def update_page_index(self, on_progress=None):
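        """Embed the prepared documents in chunks and write them to the page index.

        Calls on_progress(processed, total) as chunks complete and returns the total
        embedding token usage, or False if prepare_update_index has not been run.
        """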
        if self.unindexed_docs is None:
            return False

        chunk_limit = 500  # maximum accumulated text length per embedding batch

        chunk_len = 0
        doc_chunk = []
        total_token_usage = 0
        processed_len = 0

        async def on_embedding_progress(current, length):
            nonlocal processed_len

            indexed_docs = processed_len + current

            if on_progress is not None:
                await on_progress(indexed_docs, len(self.unindexed_docs))

        async def embedding_doc(doc_chunk):
            (doc_chunk, token_usage) = await self.text_embedding.get_embeddings(
                doc_chunk, on_embedding_progress
            )
            await self.page_index.index_doc(doc_chunk, self.page_id)

            return token_usage

        if len(self.unindexed_docs) > 0:
            if on_progress is not None:
                await on_progress(0, len(self.unindexed_docs))

            for doc in self.unindexed_docs:
                # Accumulate by text length; len(doc) on a dict would only count its keys
                chunk_len += len(doc.get("text", ""))

                if chunk_len > chunk_limit and len(doc_chunk) > 0:
                    total_token_usage += await embedding_doc(doc_chunk)
                    processed_len += len(doc_chunk)
                    if on_progress is not None:
                        await on_progress(processed_len, len(self.unindexed_docs))

                    doc_chunk = []
                    chunk_len = len(doc.get("text", ""))

                doc_chunk.append(doc)

            if len(doc_chunk) > 0:
                total_token_usage += await embedding_doc(doc_chunk)
            if on_progress is not None:
                await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))

            await self.page_index.remove_outdated_doc(self.sentences, self.page_id)

        # Update database records
        # Indexing may take a long time, so refresh the model to pick up the latest data
        self.title_index = await self.title_index_helper.refresh(self.title_index)

        self.title_index.indexed_rev_id = self.page_info["lastrevid"]

        # Update title embedding
        if (await self.title_index.awaitable_attrs.embedding) is None:
            doc_chunk = [{"text": self.title}]
            (doc_chunk, token_usage) = await self.text_embedding.get_embeddings(doc_chunk)
            total_token_usage += token_usage

            embedding = doc_chunk[0]["embedding"]
            self.title_index.embedding = embedding

        await self.title_index_helper.update(self.title_index)

        return total_token_usage

    async def search(
        self,
        query: str,
        limit: int = 10,
        in_collection: bool = False,
        distance_limit: Optional[float] = None,
    ):
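        """Embed the query and search the page index; returns (results filtered by distance_limit, token usage)."""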
        if limit == 0:
            return [], 0

        if distance_limit is None:
            distance_limit = Config.get("embedding.default_distance_limit")
        
        if self.page_index is None:
            raise Exception("Page index is not initialized")

        query_doc = [{"text": query}]
        query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
        query_embedding = query_doc[0]["embedding"]

        if query_embedding is None:
            return [], token_usage

        res = await self.page_index.search_text_embedding(
            query_embedding, in_collection, limit, self.page_id
        )
        if res:
            filtered = []
            for one in res:
                if one["distance"] < distance_limit:
                    filtered.append(dict(one))
            return filtered, token_usage
        else:
            return res, token_usage