import asyncio
import json
import time
import traceback
from local import noawait
from typing import Optional
from aiohttp import WSMsgType, web
from sqlalchemy import select
from sqlalchemy import inspect as sa_inspect
from api.model.chat_complete.conversation import ConversationModel, ConversationChunkModel
from noawait import NoAwaitPool
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.tiktoken import TikTokenService
import utils.web

class ChatCompleteTaskList:
    """Skeleton for a chat-complete background task (work in progress).

    ``run`` is currently a no-op; ``start`` invokes it through ``noawait.wrap``.
    """

    def __init__(self, dbs: DatabaseService):
        # Keep the database service for use by ``run`` (it was previously discarded).
        self.dbs = dbs
        # Callback slot; not used by any code in this class yet.
        self.on_message = None
        self.chunks: list[str] = []

    async def run(self):
        """Execute the task. Not implemented yet; currently does nothing."""
        # `self` was missing from the signature, so `self.run()` raised TypeError.
        pass

    @noawait.wrap
    async def start(self):
        """Run the task.

        NOTE(review): ``noawait.wrap`` presumably schedules this coroutine so the
        caller does not have to await it — confirm against the noawait module.
        """
        await self.run()

class ChatComplete:
    """aiohttp request handlers for the chat-complete API.

    Every public handler is a ``staticmethod`` wrapped by
    ``utils.web.token_auth``. Handlers read ``request.get("caller")`` and
    ``request.get("user")``, which that decorator presumably populates —
    confirm in ``utils.web``.
    """

    @staticmethod
    def _resolve_user_id(request: web.Request, params: dict) -> Optional[int]:
        """Return the id of the user whose data this request may access.

        End-user callers (``caller == "user"``) always act on their own
        account; any other caller names the user via the ``user_id`` param.
        """
        if request.get("caller") == "user":
            return request.get("user")
        return params.get("user_id")

    @staticmethod
    def _model_to_dict(instance) -> dict:
        """Return the already-loaded column values of an ORM *instance*.

        Unlike ``instance.__dict__``, the result omits SQLAlchemy's internal
        ``_sa_instance_state`` and any loaded relationship objects, neither of
        which is plain data suitable for an API response.
        """
        state = sa_inspect(instance)
        column_keys = {attr.key for attr in state.mapper.column_attrs}
        # state.dict only holds values that are already loaded, so reading it
        # never triggers a lazy load (which an AsyncSession cannot do implicitly).
        return {key: value for key, value in state.dict.items() if key in column_keys}

    @staticmethod
    async def _send_ws_error(ws: web.WebSocketResponse, status: int, message: str,
                             code: str, title: str):
        """Send an ``error`` event over *ws*; do nothing if it is already closed."""
        if ws.closed:
            return
        await ws.send_json({
            'event': 'error',
            'status': status,
            'message': message,
            'error': {
                'code': code,
                'title': title,
            },
        })

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk_list(request: web.Request):
        """List the chunks (``id`` and ``updated_at``) of one conversation.

        Params: ``conversation_id`` (required) and ``user_id`` (optional; ignored
        for end-user callers, who may only read their own conversations).

        Responds 404 if the conversation does not exist, 403 if it belongs to
        another user; otherwise the chunk list ordered by ascending id.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int
            },
            "conversation_id": {
                "required": True,
                "type": int
            }
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        conversation_id = params.get("conversation_id")

        db = await DatabaseService.create(request.app)

        async with db.create_session() as session:
            stmt = select(ConversationModel).where(
                ConversationModel.id == conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-not-found",
                    "message": "Conversation not found."
                }, http_status=404, request=request)

            if conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            # Select only the two needed columns. The result rows are tuples, so
            # iterate execute(); scalars() would yield just the first column
            # (the bare id), and `.id` / `.updated_at` would fail on it.
            stmt = select(ConversationChunkModel.id, ConversationChunkModel.updated_at) \
                .where(ConversationChunkModel.conversation_id == conversation_id) \
                .order_by(ConversationChunkModel.id.asc())

            rows = await session.execute(stmt)

            conversation_chunk_list = [
                {"id": row.id, "updated_at": row.updated_at}
                for row in rows
            ]

        return await utils.web.api_response(1, conversation_chunk_list, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_conversation_chunk(request: web.Request):
        """Return one conversation chunk's column values.

        Params: ``chunk_id`` (required) and ``user_id`` (optional; ignored for
        end-user callers).

        Responds 404 if the chunk does not exist, 403 if its conversation is
        missing or belongs to another user.
        """
        params = await utils.web.get_param(request, {
            "user_id": {
                "required": False,
                "type": int,
            },
            "chunk_id": {
                "required": True,
                "type": int,
            },
        })

        user_id = ChatComplete._resolve_user_id(request, params)
        chunk_id = params.get("chunk_id")

        dbs = await DatabaseService.create(request.app)
        async with dbs.create_session() as session:
            stmt = select(ConversationChunkModel).where(
                ConversationChunkModel.id == chunk_id)

            chunk_data = await session.scalar(stmt)

            if chunk_data is None:
                return await utils.web.api_response(-1, error={
                    "code": "conversation-chunk-not-found",
                    "message": "Conversation chunk not found."
                }, http_status=404, request=request)

            # Load the owning conversation with an explicit query rather than
            # through the `conversation` relationship: an implicit lazy load is
            # not supported on an AsyncSession.
            stmt = select(ConversationModel).where(
                ConversationModel.id == chunk_data.conversation_id)

            conversation_data = await session.scalar(stmt)

            if conversation_data is None or conversation_data.user_id != user_id:
                return await utils.web.api_response(-1, error={
                    "code": "permission-denied",
                    "message": "Permission denied."
                }, http_status=403, request=request)

            chunk_dict = ChatComplete._model_to_dict(chunk_data)

        return await utils.web.api_response(1, chunk_dict, request=request)

    @staticmethod
    @utils.web.token_auth
    async def get_tokens(request: web.Request):
        """Return ``{"tokens": ...}`` as computed by TikTokenService for ``question``."""
        params = await utils.web.get_param(request, {
            "question": {
                "type": str,
                "required": True
            }
        })

        question = params.get("question")

        tiktoken = await TikTokenService.create()
        tokens = await tiktoken.get_tokens(question)

        return await utils.web.api_response(1, {"tokens": tokens}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def chat_complete(request: web.Request):
        """Run a chat completion about wiki page ``title`` over a WebSocket.

        Non-WebSocket requests get a 400 ``protocol-mismatch`` response. Over
        the socket the handler sends one ``done`` event (status 1, plus the
        service result) or an ``error`` event (status -1 on failure, -2 when
        the page has no index), then closes the socket.

        For end-user callers a MediaWiki points transaction is opened first. It
        is ended with the total token count on success and cancelled on error.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "type": str,
                "required": True,
            },
            "question": {
                "type": str,
                "required": True,
            },
            "conversation_id": {
                "type": int,
                "required": False,
            },
            "extra_limit": {
                "type": int,
                "required": False,
                "default": 10,
            },
            "in_collection": {
                "type": bool,
                "required": False,
                "default": False,
            },
        })

        if not utils.web.is_websocket(request):
            return await utils.web.api_response(-1, request=request, error={
                "code": "protocol-mismatch",
                "message": "Protocol mismatch, websocket request expected."
            }, http_status=400)

        user_id = request.get("user")
        caller = request.get("caller")

        page_title = params.get("title")
        question = params.get("question")
        conversation_id = params.get("conversation_id")

        extra_limit = params.get("extra_limit")
        in_collection = params.get("in_collection")

        dbs = await DatabaseService.create(request.app)
        tiktoken = await TikTokenService.create()
        mwapi = MediaWikiApi.create()

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        try:
            async with ChatCompleteService(dbs, page_title) as chat_complete_service:
                if not await chat_complete_service.page_index_exists():
                    await ChatComplete._send_ws_error(
                        ws, -2, "Page index not found.", 'page_not_found', page_title)
                else:
                    tokens = await tiktoken.get_tokens(question)

                    transaction_id = None
                    point_cost = 0
                    if caller == "user":
                        usage_res = await mwapi.chat_complete_start_transaction(
                            user_id, "chatcomplete", tokens, extra_limit)
                        transaction_id = usage_res.get("transaction_id")
                        point_cost = usage_res.get("point_cost")

                    # NOTE(review): these callbacks are not passed to the
                    # service, so nothing is streamed to the client yet; they are
                    # kept to document the intended wire format.
                    async def on_message(text: str):
                        # Plain text prefixed with "+" marks a message chunk;
                        # wrapping it in JSON would make each packet far larger.
                        await ws.send_str("+" + text)

                    async def on_extracted_doc(doc: list):
                        await ws.send_json({
                            'event': 'extract_doc',
                            'status': 1,
                            'doc': doc
                        })

                    try:
                        chat_res = await chat_complete_service.prepare_chat_complete(
                            question,
                            conversation_id=conversation_id,
                            user_id=user_id,
                            embedding_search={
                                "limit": extra_limit,
                                "in_collection": in_collection,
                            })
                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            **chat_res,
                        })

                        await chat_complete_service.set_latest_point_cost(point_cost)

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(
                                transaction_id, chat_res["total_tokens"])
                    except Exception as e:
                        err_msg = f"Error while processing chat complete request: {e}"
                        traceback.print_exc()

                        await ChatComplete._send_ws_error(
                            ws, -1, err_msg, 'internal_error', page_title)
                        # Refund the user: the completion did not finish.
                        if transaction_id:
                            await mwapi.chat_complete_cancel_transaction(
                                transaction_id, error=err_msg)
        # Any other failure: service setup/teardown, token counting, starting
        # the transaction, or the cancellation call above.
        except Exception as e:
            err_msg = f"Error while processing chat complete request: {e}"
            traceback.print_exc()

            await ChatComplete._send_ws_error(ws, -1, err_msg, 'internal_error', page_title)
        finally:
            if not ws.closed:
                await ws.close()

        # aiohttp requires the handler to return the prepared response object.
        return ws