from __future__ import annotations
import asyncio
import time
import traceback
from server.controller.task.ChatCompleteTask import ChatCompleteTask
from server.model.base import clone_model
from server.model.chat_complete.bot_persona import BotPersonaHelper
from server.model.toolbox_ui.conversation import ConversationHelper
from utils.local import noawait
from aiohttp import web
from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException
from service.tiktoken import TikTokenService
import utils.web

class ChatComplete:
    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunks of a conversation owned by the requester.

        Params:
            id: conversation id (required).
            user_id: owner to act as; only honored when the caller is not an
                end user (end users always act as the authenticated user).

        Responds 404 ``conversation-not-found`` if the conversation does not
        exist, 403 ``permission-denied`` if it belongs to another user, and
        otherwise with the items yielded by
        ``ConversationChunkHelper.get_chunk_id_list`` as a list.
        """
        param_spec = {
            "user_id": {"required": False, "type": int},
            "id": {"required": True, "type": int},
        }
        params = await utils.web.get_param(request, param_spec)

        acting_user = (
            request.get("user")
            if request.get("caller") == "user"
            else params.get("user_id")
        )
        target_conversation_id = params.get("id")

        db = await DatabaseService.create(request.app)

        async with ConversationHelper(db) as conv_helper, ConversationChunkHelper(db) as chunk_helper:
            conversation = await conv_helper.find_by_id(target_conversation_id)

            if conversation is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation.user_id != acting_user:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Copy into a plain list while the helpers are still open; the
            # helper's return value may be a lazy result object (unverified).
            chunk_ids = list(await chunk_helper.get_chunk_id_list(target_conversation_id))

        return await utils.web.api_response(1, chunk_ids, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk if the requester owns its conversation.

        Params:
            chunk_id: id of the chunk to fetch (required).
            user_id: owner to act as; only honored when the caller is not an
                end user (end users always act as the authenticated user).

        Responds 404 if the chunk or its parent conversation does not exist,
        403 if the conversation belongs to another user, and otherwise with a
        dict holding the chunk's id, conversation_id, message_data, tokens and
        updated_at.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        chunk_id = params.get("chunk_id")

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
            if chunk_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            conversation_info = await conversation_helper.find_by_id(chunk_info.conversation_id)

            # Ownership is decided by the parent conversation. A chunk whose
            # conversation no longer exists has no owner to check against, so
            # refuse it rather than serving it to every caller.
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

        chunk_dict = {
            "id": chunk_info.id,
            "conversation_id": chunk_info.conversation_id,
            "message_data": chunk_info.message_data,
            "tokens": chunk_info.tokens,
            "updated_at": chunk_info.updated_at,
        }

        return await utils.web.api_response(1, chunk_dict, request=request)
    
    @staticmethod
    @utils.web.token_auth
    async def fork_conversation(request: web.Request):
        """Create a copy of a conversation, optionally truncated at a message.

        Params:
            id: id of the conversation to fork (required).
            user_id: owner to act as; only honored when the caller is not an
                end user (end users always act as the authenticated user).
            message_id: optional ``"<chunk_id>,<message_id>"``. When given, that
                chunk is copied and every message after ``message_id`` is
                dropped; otherwise the conversation's newest chunk is copied.
            new_title: optional title for the new conversation.

        The copied chunk gets a leading ``notice``/``forked`` message that
        points back at the original conversation. If the copy contains an
        assistant message, the new conversation's description is set to the
        first 150 characters of the last one.

        Responds 400 ``invalid-params`` for a malformed ``message_id``, 404 when
        the conversation, chunk or message does not exist, and 403 when the
        conversation belongs to another user. Otherwise responds with
        ``{"conversation_id": <new id>}``.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "id": {
                "required": True,
                "type": int,
            },
            "message_id": {
                "required": False,
                "type": str
            },
            "new_title": {
                "required": False,
                "type": str
            }
        })

        if request.get("caller") == "user":
            user_id = request.get("user")
        else:
            user_id = params.get("user_id")

        conversation_id: int = params.get("id")
        packed_message_id: str | None = params.get("message_id")
        new_title = params.get("new_title")

        chunk_id = None
        msg_id = None
        if packed_message_id is not None:
            # Expected format: "<chunk_id>,<message_id>". A malformed value is a
            # client error, so report 400 instead of crashing with ValueError.
            try:
                raw_chunk_id, msg_id = packed_message_id.split(",", 1)
                chunk_id = int(raw_chunk_id)
            except ValueError:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid message_id. Expected \"<chunk_id>,<message_id>\"."
                }, http_status=400, request=request)

        db = await DatabaseService.create(request.app)
        async with ConversationHelper(db) as conversation_helper, ConversationChunkHelper(db) as conversation_chunk_helper:
            conversation_info = await conversation_helper.find_by_id(conversation_id)
            if conversation_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_info.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Pick the chunk to copy: the requested one, or the newest one.
            if chunk_id is not None:
                chunk_info = await conversation_chunk_helper.find_by_id(chunk_id)
                if chunk_info is None or chunk_info.conversation_id != conversation_id:
                    return await utils.web.api_response(-1, error={
                        "code": "conversation-chunk-not-found",
                        "message": "Conversation chunk not found."
                    }, http_status=404, request=request)
            else:
                chunk_info = await conversation_chunk_helper.get_newest_chunk(conversation_id)

            # Build the copied message list before writing anything, so that an
            # unknown message id cannot leave an orphan conversation behind.
            message_data = None
            if chunk_info is not None:
                # Fresh list: never mutate the source chunk's message data.
                message_data = list(chunk_info.message_data or [])
                if msg_id is not None:
                    split_message_pos = next(
                        (i for i, msg in enumerate(message_data) if msg.get("id") == msg_id),
                        None,
                    )
                    if split_message_pos is None:
                        return await utils.web.api_response(-1, error={
                            "code": "message-not-found",
                            "message": "Message not found."
                        }, http_status=404, request=request)
                    # Keep the selected message and drop everything after it.
                    message_data = message_data[:split_message_pos + 1]

            new_conversation: ConversationModel = clone_model(conversation_info)

            if new_title is not None:
                new_conversation.title = new_title

            new_conversation = await conversation_helper.add(new_conversation)

            if message_data is not None:
                new_chunk: ConversationChunkModel = clone_model(chunk_info)
                new_chunk.conversation_id = new_conversation.id

                message_data.insert(0, {
                    "id": utils.web.generate_uuid(),
                    "role": "notice",
                    "type": "forked",
                    "data": {
                        "original_conversation_id": conversation_info.id,
                        "original_title": conversation_info.title,
                    }
                })
                new_chunk.message_data = message_data

                # Use the last assistant message as the new description.
                last_assistant_message = next(
                    (msg for msg in reversed(message_data) if msg.get("role") == "assistant"),
                    None,
                )
                if last_assistant_message is not None:
                    new_conversation.description = (last_assistant_message.get("content") or "")[0:150]
                    # Must be awaited like the other helper calls; otherwise the
                    # coroutine never runs and the description is not saved.
                    await conversation_helper.update(new_conversation)

                await conversation_chunk_helper.add(new_chunk)

        return await utils.web.api_response(1, {
            "conversation_id": new_conversation.id,
        }, request=request)
    
    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Report the token figure for ``question`` from TikTokenService.

        Params:
            question: text to tokenize (required).

        Responds with ``{"tokens": ...}``, holding whatever
        ``TikTokenService.get_tokens`` returns (presumably a token count).
        """
        spec = {"question": {"type": str, "required": True}}
        query = await utils.web.get_param(request, spec)
        text = query.get("question")

        token_service = await TikTokenService.create()
        token_total = await token_service.get_tokens(text)

        payload = {"tokens": token_total}
        return await utils.web.api_response(1, payload, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_point_cost(request: web.Request):
        """Estimate the point cost of asking ``question`` via chat complete.

        Params:
            question: text whose tokens are counted (required).
            extract_limit: forwarded to the MediaWiki point-cost query
                (default 10).

        The tokens are computed locally with TikTokenService. They are then
        sent, with the authenticated user id, to
        ``MediaWikiApi.ai_toolbox_get_point_cost``. Responds with
        ``{"point_cost": ..., "tokens": ...}``, or with a 500
        ``internal-server-error`` if that query fails.
        """
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
        })

        user_id = request.get("user")

        question = params.get("question")
        extract_limit = params.get("extract_limit")

        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        tokens = await tiktoken.get_tokens(question)

        # Guard only the remote lookup: its failures are reported to the client
        # as a 500, while errors elsewhere are not misreported as point-cost errors.
        try:
            res = await mwapi.ai_toolbox_get_point_cost(user_id, "chatcomplete", tokens, extract_limit)
            point_cost = res["point_cost"]
        except Exception as e:
            err_msg = f"Error while get chat complete point cost: {e}"
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": err_msg
            }, http_status=500, request=request)

        return await utils.web.api_response(1, {
            "point_cost": point_cost,
            "tokens": tokens,
        }, request=request)
        
    @staticmethod
    @utils.web.token_auth
    async def get_persona_list(request: web.Request):
        """Return one page of bot personas, optionally filtered by category.

        Params:
            category_id: optional category filter.
            page: page number (default 1).

        Responds with ``{"list": [...], "page_count": ...}``. Each list entry
        holds the persona's id, bot_id, bot_name, bot_avatar, bot_description
        and updated_at.
        """
        query = await utils.web.get_param(request, {
            "category_id": {"type": int, "required": False},
            "page": {"type": int, "required": False, "default": 1},
        })
        category = query.get("category_id")
        page_no = query.get("page")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as persona_helper:
            personas = await persona_helper.get_list(page=page_no, category_id=category)
            total_pages = await persona_helper.get_page_count(category_id=category)

        summaries = [
            {
                "id": p.id,
                "bot_id": p.bot_id,
                "bot_name": p.bot_name,
                "bot_avatar": p.bot_avatar,
                "bot_description": p.bot_description,
                "updated_at": p.updated_at,
            }
            for p in personas
        ]

        return await utils.web.api_response(1, {
            "list": summaries,
            "page_count": total_pages,
        }, request=request)
    
    @staticmethod
    @utils.web.token_auth
    async def get_persona_info(request: web.Request):
        """Return the public attributes of one bot persona.

        Params (one is needed; ``id`` takes precedence when both are given):
            id: persona id.
            bot_id: bot id of the persona.

        Responds 400 ``invalid-params`` when neither is given and 404
        ``persona-not-found`` when no persona matches. Otherwise responds with
        a dict of the persona's instance attributes, skipping names that start
        with ``_``.
        """
        params = await utils.web.get_param(request, {
            "id": {
                "type": int,
            },
            "bot_id": {
                "type": str,
            }
        })

        persona_id = params.get("id")
        bot_id = params.get("bot_id")

        db = await DatabaseService.create(request.app)
        async with BotPersonaHelper(db) as bot_persona_helper:
            if persona_id is not None:
                persona_info = await bot_persona_helper.find_by_id(persona_id)
            elif bot_id is not None:
                persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
            else:
                return await utils.web.api_response(-1, error={
                    "code": "invalid-params",
                    "message": "Invalid params. Please specify id or bot_id."
                }, http_status=400, request=request)

            # Without this check a missing persona crashes on
            # ``None.__dict__`` and surfaces as a 500.
            if persona_info is None:
                return await utils.web.api_response(-1, error={
                    "code": "persona-not-found",
                    "message": "Persona not found."
                }, http_status=404, request=request)

            # Skip private/internal attributes (leading underscore).
            persona_info_res = {
                key: value
                for key, value in persona_info.__dict__.items()
                if not key.startswith("_")
            }

        return await utils.web.api_response(1, persona_info_res, request=request)

    @staticmethod
    @utils.web.token_auth
    async def start_chat_complete(request: web.Request):
        """Initialise a ChatCompleteTask and return its id for streaming.

        Params:
            title: page title passed to ChatCompleteTask (required).
            question: the user's question (required).
            conversation_id: optional conversation to continue.
            bot_id: optional bot id.
            extract_limit: embedding-search ``limit`` (default 10).
            in_collection: embedding-search ``in_collection`` flag
                (default False).
            edit_message_id: optional id of a message to edit.

        After ``init`` succeeds, ``task.run()`` is handed to
        ``noawait.add_task`` without being awaited here. The response carries
        the ``task_id`` that ``chat_complete_stream`` looks up.

        Error responses: 404 ``page-not-found``, 403 ``no-enough-points``,
        400 ``question-too-long``, and 500 ``internal-server-error`` for any
        other exception.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "bot_id": {
                "type": str,
                "required": False,
            },
            "extract_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
            "edit_message_id": {
                "type": str,
                "required": False,
            },
        })

        user_id = request.get("user")
        caller = request.get("caller")

        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")
        bot_id = params.get("bot_id")

        extract_limit = params.get("extract_limit")
        in_collection = params.get("in_collection")

        edit_message_id = params.get("edit_message_id")

        dbs = await DatabaseService.create(request.app)

        try:
            chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
            init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id,
                bot_id=bot_id,
                embedding_search={
                    "limit": extract_limit,
                    "in_collection": in_collection,
                })

            noawait.add_task(chat_complete_task.run())

            return await utils.web.api_response(1, data={
                "conversation_id": init_res["conversation_id"],
                "chunk_id": init_res["chunk_id"],
                "question_tokens": init_res["question_tokens"],
                "extract_doc": init_res["extract_doc"],
                "task_id": chat_complete_task.task_id,
            }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-1, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404, request=request)
        except MediaWikiUserNoEnoughPointsException:
            # Plain string: the message has no placeholder, so applying "%"
            # to it would raise TypeError and turn this 403 into a 500.
            error_msg = "Does not have enough points."
            return await utils.web.api_response(-1, error={
                "code": "no-enough-points",
                "message": error_msg
            }, http_status=403, request=request)
        except ChatCompleteQuestionTooLongException as e:
            error_msg = "Question too long."
            return await utils.web.api_response(-1, error={
                "code": "question-too-long",
                "limit": e.tokens_limit,
                "current": e.tokens_current,
                "message": error_msg
            }, http_status=400, request=request)
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": err_msg
            }, http_status=500, request=request)
        
    @staticmethod
    @utils.web.token_auth
    async def chat_complete_stream(request: web.Request):
        """Stream the output of a ChatCompleteTask to the client over a websocket.

        Non-websocket requests get a 400 ``websocket-required`` JSON error.
        Otherwise the websocket is prepared and the task named by the
        ``task_id`` param is looked up:

        * unknown task, or a task owned by another end user: an ``error`` event
          is sent and the handler returns;
        * finished task: an ``error`` event (if ``task.error`` is set) or a
          ``connected`` + ``finished`` event pair (if ``task.result`` is set) is
          sent, then the socket is closed;
        * running task: callbacks are registered on the task and a ``connected``
          event with the output so far is sent. Each new delta is then pushed
          as a text frame prefixed with ``"+"``, and a ping goes out every 15
          seconds until the socket closes or the task finishes.
        """
        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, error={
                "code": "websocket-required",
                "message": "This API only accept websocket connection."
            }, http_status=400, request=request)
        
        params = await utils.web.get_param(request, {
            "task_id": {
                "type": str,
                "required": True,
            }
        })

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        task_id = params.get("task_id")

        task = ChatCompleteTask.get_by_id(task_id)
        if task is None:
            await ws.send_json({
                'event': 'error',
                'status': -1,
                'message': "Task not found.",
                'error': {
                    'code': "task-not-found",
                    'info': "Task not found.",
                },
            })
            # NOTE(review): returns None without closing ws; aiohttp handlers
            # normally return the WebSocketResponse — confirm whether
            # token_auth compensates for this.
            return

        # Only end-user callers are restricted to their own tasks; other
        # callers may attach to any task.
        if request.get("caller") == "user":
            user_id = request.get("user")
            if task.user_id != user_id:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': "Permission denied.",
                    'error': {
                        'code': "permission-denied",
                        'info': "Permission denied.",
                    },
                })
                # NOTE(review): same as above — returns None, socket left open.
                return
        
        if task.is_finished:
            # Task already done: replay its outcome in one go, then close.
            if task.error is not None:
                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': str(task.error),
                    'error': {
                        'code': "internal-server-error",
                        'info': str(task.error),
                    },
                })
                await ws.close()
            elif task.result is not None:
                await ws.send_json({
                    'event': 'connected',
                    'status': 1,
                    'outputed_message': "".join(task.chunks),
                })
                await ws.send_json({
                    'event': 'finished',
                    'status': 1,
                    'result': task.result
                })
                await ws.close()
            # NOTE(review): a finished task with neither error nor result sends
            # nothing and leaves the socket open.
        else:
            # Detach this connection's three callbacks from the task.
            # NOTE(review): list.remove raises ValueError when this runs a second
            # time — e.g. a send hits ConnectionResetError (first call), then the
            # loop below sees the task finish or the ping fail (second call).
            # Consider making it idempotent.
            async def on_closed():
                task.on_message.remove(on_message)
                task.on_finished.remove(on_finished)
                task.on_error.remove(on_error)

            # Each streamed delta goes out as a text frame prefixed with "+".
            async def on_message(delta_message: str):
                try:
                    await ws.send_str("+" + delta_message)
                except ConnectionResetError:
                    await on_closed()

            # Send the final result (every key except "message") together with
            # the task's point_cost, then close the socket.
            async def on_finished(result: ChatCompleteServiceResponse):
                try:
                    ignored_keys = ["message"]
                    response_result = {
                        "point_cost": task.point_cost,
                    }
                    for k, v in result.items():
                        if k not in ignored_keys:
                            response_result[k] = v
                    await ws.send_json({
                        'event': 'finished',
                        'status': 1,
                        'result': response_result
                    })

                    await ws.close()
                except ConnectionResetError:
                    await on_closed()

            # Report a task failure as an "error" event, then close the socket.
            async def on_error(err: Exception):
                try:
                    await ws.send_json({
                        'event': 'error',
                        'status': -1,
                        'message': str(err),
                        'error': {
                            'code': "internal-server-error",
                            'info': str(err),
                        },
                    })

                    await ws.close()
                except ConnectionResetError:
                    await on_closed()

            task.on_message.append(on_message)
            task.on_finished.append(on_finished)
            task.on_error.append(on_error)

            # Replay the output produced so far to the newly connected client.
            await ws.send_json({
                'event': 'connected',
                'status': 1,
                'outputed_message': "".join(task.chunks),
            })

            # Poll every 0.1 s until the socket closes or the task finishes,
            # sending a heartbeat ping every 15 s.
            # NOTE(review): nothing here reads from ws, so a close frame from the
            # client is presumably only noticed once a ping fails — confirm
            # against aiohttp's WebSocketResponse.closed semantics.
            last_heartbeat = time.time()
            while True:
                current_time = time.time()
                if ws.closed or task.is_finished:
                    await on_closed()
                    break
                
                if current_time - last_heartbeat >= 15:
                    try:
                        await ws.ping('{"event":"ping"}'.encode('utf-8'))
                        last_heartbeat = current_time
                    except ConnectionResetError:
                        await on_closed()
                        break
                await asyncio.sleep(0.1)