from __future__ import annotations
import time
import traceback
from typing import Optional, Tuple, TypedDict

import sqlalchemy
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.chat_complete.conversation import (
    ConversationChunkHelper,
    ConversationChunkModel,
)
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel

from config import Config
import utils.config, utils.web

from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionModel
from sqlalchemy.orm.attributes import flag_modified

from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService

class ChatCompleteQuestionTooLongException(Exception):
    """Raised when a question needs more tokens than the configured input limit.

    Attributes:
        tokens_limit: maximum number of tokens accepted for a question.
        tokens_current: number of tokens the rejected question needs.
    """

    def __init__(self, tokens_limit: int, tokens_current: int):
        self.tokens_limit = tokens_limit
        self.tokens_current = tokens_current
        message = f"Question too long: {tokens_current} > {tokens_limit}"
        super().__init__(message)

class ChatCompleteServicePrepareResponse(TypedDict):
    """Result returned by ``ChatCompleteService.prepare_chat_complete``."""

    # Documents found by the embedding search; None when no search was
    # requested or the search itself returned None.
    extract_doc: Optional[list]
    # Token count of the question, plus the embedding-search token usage
    # when documents were found.
    question_tokens: int
    # Id of the conversation (existing or newly created).
    conversation_id: int
    # Id of the conversation chunk the next messages will be appended to.
    chunk_id: int

# Result returned by ChatCompleteService.finish_chat_complete: the model
# reply, its token usage and finish reason, the ids assigned to the stored
# question/answer messages, and per-turn extra data (e.g. a new title).
# Declared with the functional TypedDict syntax; every key is required.
ChatCompleteServiceResponse = TypedDict(
    "ChatCompleteServiceResponse",
    {
        "message": str,
        "message_tokens": int,
        "total_tokens": int,
        "finish_reason": str,
        "question_message_id": str,
        "response_message_id": str,
        "delta_data": dict,
    },
)

class ChatCompleteService:
    """Run one chat-completion turn against the index of a wiki page.

    Flow implemented below: enter the instance with ``async with`` (creates
    the TikToken service and enters the DB helpers), call
    ``prepare_chat_complete`` to load or create the conversation, then
    ``finish_chat_complete`` to query the model and store the exchange.
    """

    def __init__(self, dbs: DatabaseService, title: str):
        self.dbs = dbs

        self.title = title
        # Part of the title before the first "/".
        self.base_title = title.split("/")[0]

        self.embedding_search = EmbeddingSearchService(dbs, title)
        self.conversation_helper = ConversationHelper(dbs)
        self.conversation_chunk_helper = ConversationChunkHelper(dbs)
        self.bot_persona_helper = BotPersonaHelper(dbs)

        self.conversation_info: Optional[ConversationModel] = None
        self.conversation_chunk: Optional[ConversationChunkModel] = None

        # Created asynchronously in __aenter__.
        self.tiktoken: Optional[TikTokenService] = None

        # Embedding-search results for the current question, or None.
        self.extract_doc: Optional[list] = None

        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        self.user_id = 0
        self.question = ""
        self.question_tokens: Optional[int] = None
        self.bot_id: str = ""
        self.conversation_id: Optional[int] = None
        self.conversation_start_time: Optional[int] = None

        # System prompt of the resolved bot persona. None/empty means
        # "use the configured default prompt" in finish_chat_complete.
        self.chat_system_prompt: Optional[str] = None

        self.delta_data = {}

    async def __aenter__(self):
        """Create the tokenizer service and enter every DB helper."""
        self.tiktoken = await TikTokenService.create()

        await self.embedding_search.__aenter__()
        await self.conversation_helper.__aenter__()
        await self.conversation_chunk_helper.__aenter__()
        await self.bot_persona_helper.__aenter__()

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Exit every DB helper, forwarding the exception information."""
        await self.embedding_search.__aexit__(exc_type, exc, tb)
        await self.conversation_helper.__aexit__(exc_type, exc, tb)
        await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)
        await self.bot_persona_helper.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self):
        """Return the embedding service's answer to ``page_index_exists(False)``."""
        return await self.embedding_search.page_index_exists(False)

    async def get_question_tokens(self, question: str):
        """Return the token count of ``question`` as computed by the TikToken service."""
        return await self.tiktoken.get_tokens(question)

    async def prepare_chat_complete(
        self,
        question: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[int] = None,
        question_tokens: Optional[int] = None,
        edit_message_id: Optional[str] = None,
        bot_id: Optional[str] = None,
        embedding_search: Optional[EmbeddingSearchArgs] = None,
    ) -> ChatCompleteServicePrepareResponse:
        """Load (or create) the conversation state for a new question.

        Args:
            question: the user's question.
            conversation_id: id of an existing conversation (converted with
                ``int()``); when None or not found a new conversation is created.
            user_id: id of the asking user; must match the conversation owner.
            question_tokens: precomputed token count; computed with the
                tokenizer when None. A count of 0 is replaced by the estimate
                ``len(question) * 3``.
            edit_message_id: ``"<chunk_id>,<message_id>"`` to re-ask from an
                earlier message; later chunks are deleted immediately and the
                newest chunk is cut just before ``message_id``.
            bot_id: persona to use for a new conversation (existing
                conversations keep the persona stored in their ``extra``).
            embedding_search: keyword arguments for the embedding search;
                no search runs when None.

        Returns:
            Extracted documents, question token count, conversation id and
            the id of the chunk that will receive the new messages.

        Raises:
            web.HTTPUnauthorized: the conversation belongs to another user.
            ChatCompleteQuestionTooLongException: the question exceeds
                ``chatcomplete.max_input_tokens``.
        """
        if user_id is not None:
            user_id = int(user_id)

        self.user_id = user_id
        self.question = question
        self.conversation_start_time = int(time.time())
        # Without an explicit choice the default persona is used.
        self.bot_id = bot_id or "default"

        self.conversation_info = None
        self.conversation_id = None
        if conversation_id is not None:
            self.conversation_id = int(conversation_id)
            self.conversation_info = await self.conversation_helper.find_by_id(
                self.conversation_id
            )

        if self.conversation_info is not None and self.conversation_info.user_id != user_id:
            raise web.HTTPUnauthorized()

        if question_tokens is None:
            self.question_tokens = await self.get_question_tokens(question)
        else:
            self.question_tokens = question_tokens
        if self.question_tokens == 0:
            # No count available: fall back to a rough per-character estimate.
            self.question_tokens = len(question) * 3

        max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int)
        if self.question_tokens > max_input_tokens:
            # Questions over the limit are rejected, not truncated.
            raise ChatCompleteQuestionTooLongException(max_input_tokens, self.question_tokens)

        if self.conversation_info is not None:
            # An existing conversation keeps the persona it was created with.
            self.bot_id = self.conversation_info.extra.get("bot_id") or "default"

        bot_persona = await self._resolve_bot_persona()

        self.conversation_chunk = None
        if self.conversation_info is not None:
            await self._load_conversation_chunk(edit_message_id)
            await self._summarize_if_too_long()
        else:
            await self._create_conversation(bot_persona)

        # Extract documents from the wiki page index.
        self.extract_doc = None
        if embedding_search is not None:
            self.extract_doc, token_usage = await self.embedding_search.search(
                question, **embedding_search
            )
            if self.extract_doc is not None:
                self.question_tokens += token_usage

        return ChatCompleteServicePrepareResponse(
            extract_doc=self.extract_doc,
            question_tokens=self.question_tokens,
            conversation_id=self.conversation_info.id,
            chunk_id=self.conversation_chunk.id,
        )

    async def _resolve_bot_persona(self):
        """Look up the persona for ``self.bot_id`` (falling back to "default")
        and record its system prompt in ``self.chat_system_prompt``."""
        bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
        if bot_persona is None and self.bot_id != "default":
            self.bot_id = "default"
            bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)

        # Always taken from the persona actually in use (including the
        # fallback); None lets finish_chat_complete use the configured default.
        self.chat_system_prompt = bot_persona.system_prompt if bot_persona is not None else None
        return bot_persona

    @staticmethod
    def _locate_edit_point(message_data: list, msg_id: str) -> Tuple[Optional[int], int]:
        """Return ``(index of the message with id msg_id or None, summed
        tokens of the messages before it)``."""
        tokens_before = 0
        for pos, msg in enumerate(message_data):
            if msg.get("id") == msg_id:
                return pos, tokens_before
            if msg.get("tokens") is not None:
                tokens_before += msg["tokens"]
        return None, tokens_before

    async def _load_conversation_chunk(self, edit_message_id: Optional[str]) -> None:
        """Load the newest chunk of the current conversation, applying an edit.

        For ``edit_message_id == "<chunk_id>,<message_id>"`` every chunk after
        ``chunk_id`` is deleted, and the newest remaining chunk is cut just
        before ``message_id`` (the in-memory change is persisted by
        finish_chat_complete).
        """
        if not edit_message_id or "," not in edit_message_id:
            self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
                self.conversation_id
            )
            return

        edit_chunk_id_str, edit_msg_id = edit_message_id.split(",", 1)
        edit_chunk_id = int(edit_chunk_id_str)

        # Remove the chunks that come after the edited one.
        chunk_id_list = list(
            await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)
        )
        if edit_chunk_id in chunk_id_list:
            later_chunk_ids = chunk_id_list[chunk_id_list.index(edit_chunk_id) + 1:]
            if later_chunk_ids:
                await self.conversation_chunk_helper.remove(later_chunk_ids)

        self.conversation_chunk = await self.conversation_chunk_helper.get_newest_chunk(
            self.conversation_id
        )

        # Remove the edited message and everything after it.
        edit_pos, tokens_before = self._locate_edit_point(
            self.conversation_chunk.message_data, edit_msg_id
        )
        # Compare with None: position 0 (first message of the chunk) is valid.
        if edit_pos is not None:
            self.conversation_chunk.message_data = self.conversation_chunk.message_data[:edit_pos]
            flag_modified(self.conversation_chunk, "message_data")
            self.conversation_chunk.tokens = tokens_before

    async def _summarize_if_too_long(self) -> None:
        """Start a new chunk holding a summary when the current chunk exceeds
        ``chatcomplete.max_memory_tokens``."""
        max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int)
        if self.conversation_chunk.tokens <= max_memory_tokens:
            return

        summary, tokens = await self.make_summary(self.conversation_chunk.message_data)
        new_chunk = ConversationChunkModel(
            conversation_id=self.conversation_id,
            message_data=[
                {
                    "role": "summary",
                    "content": summary,
                    "tokens": tokens,
                    "time": int(time.time()),
                }
            ],
            tokens=tokens,
        )
        self.conversation_chunk = await self.conversation_chunk_helper.add(new_chunk)

    async def _create_conversation(self, bot_persona) -> None:
        """Create a new conversation and its first chunk, seeded with the
        persona's ``message_log`` when a persona is available."""
        init_message_data = []
        if bot_persona is not None:
            current_time = int(time.time())
            # Copy each seed message so the persona object itself is not mutated.
            init_message_data = [
                {**message, "id": utils.web.generate_uuid(), "time": current_time}
                for message in (bot_persona.message_log or [])
            ]

        title_info = self.embedding_search.title_index
        conversation = ConversationModel(
            user_id=self.user_id,
            module="chatcomplete",
            page_id=title_info.page_id,
            rev_id=title_info.latest_rev_id,
            extra={"bot_id": self.bot_id},
        )
        self.conversation_info = await self.conversation_helper.add(conversation)
        self.conversation_id = self.conversation_info.id

        chunk = ConversationChunkModel(
            conversation_id=self.conversation_info.id,
            message_data=init_message_data,
            tokens=0,
        )
        self.conversation_chunk = await self.conversation_chunk_helper.add(chunk)

    async def finish_chat_complete(
        self, on_message: Optional[callable] = None
    ) -> ChatCompleteServiceResponse:
        """Send the prepared question to the model and store the exchange.

        Must be called after ``prepare_chat_complete``.

        Args:
            on_message: when given, ``chat_complete_stream`` is used and this
                callback is passed to it; otherwise ``chat_complete`` is used.

        Returns:
            The reply, token counts, finish reason, the ids assigned to the
            stored question/answer messages and ``delta_data`` (contains
            ``"title"`` when a title was generated on this turn).

        Raises:
            Exception: neither the persona nor the configuration provides a
                system prompt.
        """
        delta_data = {}

        message_log = []
        if self.conversation_chunk is not None:
            # NOTE(review): messages with role "summary" (created when the
            # chunk grows too long) are filtered out here, so the model never
            # sees the summary — confirm this is intended.
            message_log = [
                {"role": message["role"], "content": message["content"]}
                for message in self.conversation_chunk.message_data
                if message["role"] in ("user", "assistant")
            ]

        if self.extract_doc is not None:
            doc_prompt_content = "\n".join(
                "%d. %s" % (i + 1, doc["markdown"] or doc["text"])
                for i, doc in enumerate(self.extract_doc)
            )
            doc_prompt = utils.config.get_prompt(
                "extracted_doc", "prompt", {"content": doc_prompt_content}
            )
            message_log.append({"role": "user", "content": doc_prompt})

        # A missing or empty persona prompt falls back to the configured default.
        system_prompt = self.chat_system_prompt or utils.config.get_prompt("default", "system")
        if system_prompt is None:
            raise Exception("System prompt not found.")

        system_prompt = utils.config.format_prompt(system_prompt)

        # Start chat complete
        if on_message is not None:
            response = await self.openai_api.chat_complete_stream(
                self.question, system_prompt, message_log, on_message
            )
        else:
            response = await self.openai_api.chat_complete(
                self.question, system_prompt, message_log
            )

        description = response["message"][0:150]

        question_msg_id = utils.web.generate_uuid()
        response_msg_id = utils.web.generate_uuid()

        new_message_data = [
            {
                "id": question_msg_id,
                "role": "user",
                "content": self.question,
                "tokens": self.question_tokens,
                "time": self.conversation_start_time,
            },
            {
                "id": response_msg_id,
                "role": "assistant",
                "content": response["message"],
                "tokens": response["message_tokens"],
                "time": int(time.time()),
            },
        ]

        if self.conversation_info is not None:
            total_token_usage = self.question_tokens + response["message_tokens"]
            # Generate a title if the conversation has none yet.
            if self.conversation_info.title is None:
                title = None
                try:
                    title, _title_tokens = await self.make_title(new_message_data)
                    delta_data["title"] = title
                except Exception as e:
                    # Best-effort: a failed title must not lose the reply.
                    print(str(e), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)

                self.conversation_info.title = title

            # Update conversation info
            self.conversation_info.description = description

            await self.conversation_helper.update(self.conversation_info)

            # Append the exchange to the conversation chunk.
            self.conversation_chunk.message_data.extend(new_message_data)
            # In-place list changes are not detected by SQLAlchemy on their own.
            flag_modified(self.conversation_chunk, "message_data")
            self.conversation_chunk.tokens += total_token_usage

            await self.conversation_chunk_helper.update(self.conversation_chunk)

        return ChatCompleteServiceResponse(
            message=response["message"],
            message_tokens=response["message_tokens"],
            total_tokens=response["total_tokens"],
            finish_reason=response["finish_reason"],
            question_message_id=question_msg_id,
            response_message_id=response_msg_id,
            delta_data=delta_data,
        )

    async def set_latest_point_cost(self, point_cost: int) -> bool:
        """Store ``point_cost`` on the newest assistant message of the chunk.

        Returns:
            True when an assistant message was updated and saved, False when
            there is no chunk or it contains no assistant message.
        """
        if self.conversation_chunk is None:
            return False

        for message in reversed(self.conversation_chunk.message_data):
            if message["role"] == "assistant":
                message["point_cost"] = point_cost
                flag_modified(self.conversation_chunk, "message_data")
                await self.conversation_chunk_helper.update(self.conversation_chunk)
                return True

        return False

    async def make_summary(self, message_log_list: list) -> tuple[str, int]:
        """Ask the model to summarise ``message_log_list``.

        Returns:
            ``(summary text, token count of the summary)``.
        """
        chat_log: list[str] = []
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)

        for message_data in message_log_list:
            if message_data["role"] == "summary":
                chat_log.append(message_data["content"])
            elif message_data["role"] == "assistant":
                chat_log.append(f'{bot_name}: {message_data["content"]}')
            else:
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

        summary_system_prompt = utils.config.get_prompt("make_summary", "system")
        summary_prompt = utils.config.get_prompt(
            "make_summary", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            summary_prompt, summary_system_prompt
        )

        return response["message"], response["message_tokens"]

    async def make_title(self, message_log_list: list) -> tuple[str, int]:
        """Ask the model for a conversation title from the user/assistant messages.

        Returns:
            ``(title text, token count of the title)``.
        """
        chat_log: list[str] = []
        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)

        for message_data in message_log_list:
            if message_data["role"] == "assistant":
                chat_log.append(f'{bot_name}: {message_data["content"]}')
            elif message_data["role"] == "user":
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

        title_system_prompt = utils.config.get_prompt("make_title", "system")
        title_prompt = utils.config.get_prompt(
            "make_title", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
            title_prompt, title_system_prompt
        )
        return response["message"], response["message_tokens"]