from __future__ import annotations
import sys
import time
import traceback
from libs.config import Config
from server.model.chat_complete.bot_persona import BotPersonaHelper
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
    ChatCompleteService,
    ChatCompleteServicePrepareResponse,
    ChatCompleteServiceResponse,
    calculate_point_usage,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
from service.tiktoken import TikTokenService
import utils.web

# Registry of tasks keyed by task_id: ChatCompleteTask.run() adds the task,
# while ChatCompleteTask.end() and chat_complete_task_gc() remove entries.
chat_complete_tasks: dict[str, ChatCompleteTask] = {}


class ChatCompleteTask:
    """One chat-completion request against a single wiki page.

    Usage: construct the task, ``await init(...)`` to prepare the request
    (opening a point-usage transaction for non-system callers), then
    ``await run()`` to stream the completion. Async callbacks appended to
    ``on_message``, ``on_finished`` and ``on_error`` are awaited as the
    corresponding events happen; exceptions they raise are logged, not
    propagated.
    """

    @staticmethod
    def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
        """Return the registered task with ``task_id``, or ``None``."""
        return chat_complete_tasks.get(task_id)

    def __init__(
        self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
    ):
        """Create an un-initialised task; call ``init()`` before ``run()``.

        Args:
            dbs: Database service handed to the chat-complete and persona helpers.
            user_id: User the request (and any transaction) is attributed to.
            page_title: Wiki page whose index the completion is answered from.
            is_system: When true, no point-usage transaction is opened.
        """
        self.task_id = utils.web.generate_uuid()
        # Async callbacks: on_message(delta: str), on_finished(result),
        # on_error(exception).
        self.on_message: list[Callable] = []
        self.on_finished: list[Callable] = []
        self.on_error: list[Callable] = []
        # Every streamed delta, in arrival order.
        self.chunks: list[str] = []

        # Assigned in init(): the service object and the value returned by
        # its __aenter__().
        self.chat_complete_service: ChatCompleteService
        self.chat_complete: ChatCompleteService

        self.dbs = dbs
        self.user_id = user_id
        self.page_title = page_title
        self.is_system = is_system

        # NOTE: attribute name is misspelled but kept for external callers.
        self.transatcion_id: Optional[str] = None
        self.point_usage: int = 0

        self.is_finished = False
        self.finished_time: Optional[float] = None
        self.result: Optional[ChatCompleteServiceResponse] = None
        self.error: Optional[Exception] = None

    async def init(
        self,
        question: str,
        bot_id: str,
        conversation_id: Optional[str] = None,
        edit_message_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Prepare the chat completion; must be awaited before ``run()``.

        Estimates the token/point cost, opens a point-usage transaction for
        non-system callers, and asks the service to prepare the request.

        Returns:
            Whatever ``ChatCompleteService.prepare_chat_complete`` returns.

        Raises:
            MediaWikiPageNotFoundException: no page index exists for the page.
            Any exception raised while preparing; before it propagates, the
            transaction opened here (if any) is cancelled and ``end()`` runs.
        """
        self.tiktoken = await TikTokenService.create()

        self.mwapi = MediaWikiApi.create()

        self.chat_complete_service = ChatCompleteService(self.dbs, self.page_title)
        self.chat_complete = await self.chat_complete_service.__aenter__()

        # Only a transaction opened by this call may be cancelled on failure.
        self.transatcion_id = None

        try:
            if not await self.chat_complete.page_index_exists():
                raise MediaWikiPageNotFoundException(
                    "Page %s not found." % self.page_title
                )

            question_tokens = await self.tiktoken.get_tokens(question)

            # embedding_search is optional: fall back to 10 extracts when it is
            # None, has no "limit", or the limit is falsy.
            extract_limit = (embedding_search or {}).get("limit") or 10

            estimated_extract_tokens_per_doc = Config.get(
                "estimated_extract_tokens_per_doc", 50, int
            )

            # Expected extract tokens = number of docs x estimated tokens per doc.
            predict_tokens = extract_limit * estimated_extract_tokens_per_doc

            async with BotPersonaHelper(self.dbs) as bot_persona_helper:
                bot_persona = await bot_persona_helper.find_by_bot_id(bot_id)
            # NOTE(review): if find_by_bot_id can return None for an unknown
            # bot, the attribute access below raises AttributeError — confirm.

            self.point_usage = calculate_point_usage(
                predict_tokens,
                bot_persona.cost_fixed,
                bot_persona.cost_fixed_tokens,
                bot_persona.cost_per_token,
            )

            if not self.is_system:
                usage_res = await self.mwapi.ai_toolbox_start_transaction(
                    self.user_id,
                    "chatcomplete",
                    bot_id=bot_id,
                    tokens=predict_tokens,
                    point_usage=self.point_usage,
                )
                self.transatcion_id = usage_res["transaction_id"]

            return await self.chat_complete.prepare_chat_complete(
                question,
                conversation_id=conversation_id,
                user_id=self.user_id,
                question_tokens=question_tokens,
                edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search=embedding_search,
            )
        except Exception as e:
            # run() will never be reached, so nothing else would close a
            # transaction opened above.
            await self._cancel_transaction(
                f"Error while preparing chat complete request: {e}"
            )
            await self.end()
            raise

    async def _cancel_transaction(self, error: str):
        """Best-effort cancel of the open point-usage transaction, if any.

        Failures are logged to stderr rather than raised, so callers can keep
        running their own error handling.
        """
        if not self.transatcion_id:
            return
        try:
            await self.mwapi.ai_toolbox_cancel_transaction(
                self.transatcion_id, error=error
            )
        except Exception as e:
            print(
                "Error while cancelling transaction %s: %s" % (self.transatcion_id, e),
                file=sys.stderr,
            )
            traceback.print_exc()

    async def _on_message(self, delta_message: str):
        """Record a streamed delta and forward it to every on_message callback."""
        self.chunks.append(delta_message)

        for callback in self.on_message:
            try:
                await callback(delta_message)
            except Exception as e:
                print(
                    "Error while processing on_message callback: %s" % e,
                    file=sys.stderr,
                )
                traceback.print_exc()

    async def _on_finished(self):
        """Invoke every on_finished callback with ``self.result``."""
        for callback in self.on_finished:
            try:
                await callback(self.result)
            except Exception as e:
                print(
                    "Error while processing on_finished callback: %s" % e,
                    file=sys.stderr,
                )
                traceback.print_exc()

    async def _on_error(self, err: Exception):
        """Store ``err`` in ``self.error`` and invoke every on_error callback."""
        self.error = err
        for callback in self.on_error:
            try:
                await callback(err)
            except Exception as e:
                print(
                    "Error while processing on_error callback: %s" % e, file=sys.stderr
                )
                traceback.print_exc()

    async def run(self) -> Optional[ChatCompleteServiceResponse]:
        """Stream the completion prepared by ``init()`` and settle the transaction.

        The task is registered in ``chat_complete_tasks`` while it runs.

        On success:
            - the response is stored in ``self.result``;
            - the transaction (if any) is ended;
            - ``on_finished`` callbacks fire.

        On failure:
            - the error is logged;
            - the transaction is cancelled;
            - ``on_error`` callbacks fire;
            - the exception is NOT re-raised.

        ``end()`` runs in both cases.

        Returns:
            The response on success, ``None`` on failure.
        """
        chat_complete_tasks[self.task_id] = self
        try:
            chat_res = await self.chat_complete.finish_chat_complete(self._on_message)

            await self.chat_complete.set_latest_point_usage(self.point_usage)

            self.result = chat_res

            if self.transatcion_id:
                await self.mwapi.ai_toolbox_end_transaction(
                    self.transatcion_id,
                    self.point_usage,  # TODO: deduct points based on the tokens actually used
                )

            await self._on_finished()
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"

            print(err_msg, file=sys.stderr)
            traceback.print_exc()

            # Best-effort, so a failing cancel cannot stop on_error from firing.
            await self._cancel_transaction(err_msg)

            await self._on_error(e)
            return None
        finally:
            await self.end()

        return self.result

    async def end(self):
        """Close the chat-complete service and mark the task finished.

        The task is removed from ``chat_complete_tasks`` and ``is_finished`` /
        ``finished_time`` are set even if closing the service raises.
        """
        try:
            await self.chat_complete_service.__aexit__(None, None, None)
        finally:
            # NOTE(review): removing the task here means chat_complete_task_gc's
            # grace period for finished tasks never applies to tasks that went
            # through end() — confirm whether finished tasks should be kept.
            chat_complete_tasks.pop(self.task_id, None)
            self.is_finished = True
            self.finished_time = time.time()


# Seconds a finished task may stay in the registry before the GC drops it.
TASK_EXPIRE_TIME = 60 * 10


async def chat_complete_task_gc(
    tasks: Optional[dict[str, "ChatCompleteTask"]] = None,
    expire_time: float = TASK_EXPIRE_TIME,
):
    """Remove finished tasks whose ``finished_time`` is older than ``expire_time``.

    Args:
        tasks: Registry to sweep; defaults to the module-level
            ``chat_complete_tasks``.
        expire_time: Grace period, in seconds, after a task finishes.
    """
    if tasks is None:
        tasks = chat_complete_tasks
    now = time.time()
    # Iterate over a snapshot: deleting from a dict while iterating a live
    # view raises "RuntimeError: dictionary changed size during iteration".
    for task_id, task in list(tasks.items()):
        if (
            task.is_finished
            and task.finished_time is not None
            and now > task.finished_time + expire_time
        ):
            tasks.pop(task_id, None)


# Register the task-registry sweep with the noawait timer; presumably this
# re-runs chat_complete_task_gc every 60 seconds — confirm against
# utils.local.noawait.add_timer.
noawait.add_timer(chat_complete_task_gc, 60)