from __future__ import annotations
import asyncio
import json
from typing import Callable, Optional, TypedDict

import aiohttp
from config import Config
import numpy as np
from aiohttp_sse_client2 import client as sse_client

from service.tiktoken import TikTokenService

# One entry of a chat conversation log: the speaker role
# (e.g. "user" / "assistant" / "summary") and its text content.
ChatCompleteMessageLog = TypedDict(
    "ChatCompleteMessageLog",
    {"role": str, "content": str},
)

# Result of a chat completion call: the generated message, the token
# accounting (prompt / completion / total), and the API finish reason.
ChatCompleteResponse = TypedDict(
    "ChatCompleteResponse",
    {
        "message": str,
        "prompt_tokens": int,
        "message_tokens": int,
        "total_tokens": int,
        "finish_reason": str,
    },
)

class OpenAIApi:
    """Async client for the OpenAI / Azure OpenAI HTTP API.

    The backend is selected by the "chatcomplete.api_type" config value
    ("openai" by default, or "azure"); it determines the URL layout, the
    auth header, and a few request parameters.
    """

    # Characters accepted as end-of-sentence punctuation (ASCII plus
    # full-width CJK forms).  Used by _normalize_text: a non-empty line
    # ending in any other character gets a "." appended before embedding.
    # Per-character membership test — treated as a set of chars, not a regex.
    SENTENCE_END_CHARS = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"

    @staticmethod
    def create():
        """Factory helper; equivalent to calling OpenAIApi()."""
        return OpenAIApi()

    def __init__(self):
        # "openai" (default) or "azure"; decides URL layout and auth header.
        self.api_type = Config.get("chatcomplete.api_type", "openai", str)
        # Optional HTTP proxy for all outgoing requests.
        self.request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)

        if self.api_type == "azure":
            self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
            self.api_key = Config.get("chatcomplete.azure.key", type=str)
        else:
            self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str)
            self.api_key = Config.get("chatcomplete.openai.key", type=str)

    def build_header(self) -> dict:
        """Return the HTTP headers for the configured backend.

        Azure authenticates with an "api-key" header, plain OpenAI with a
        Bearer token in "authorization".
        """
        if self.api_type == "azure":
            return {
                "content-type": "application/json",
                "accept": "application/json",
                "api-key": self.api_key
            }
        else:
            return {
                "authorization": f"Bearer {self.api_key}",
                "content-type": "application/json",
                "accept": "application/json",
            }

    def get_url(self, method: str) -> str:
        """Return the absolute request URL for an API method.

        For Azure the deployment name comes from configuration and is
        embedded in the path.  Raises ValueError for an Azure method with
        no configured deployment (previously this silently returned None).
        """
        if self.api_type == "azure":
            deployments = Config.get("chatcomplete.azure.deployments")
            if method == "chat/completions":
                return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
            elif method == "embeddings":
                return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method
            raise ValueError("Unsupported Azure API method: %s" % method)
        else:
            return self.api_url + "/v1/" + method

    def _normalize_text(self, text: str) -> str:
        """Collapse a document to a single line for embedding.

        Normalizes line endings to "\\n"; every non-empty line that does
        not already end with a punctuation character gets a "." appended;
        the lines are then joined with spaces.
        """
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        if "\n" not in text:
            return text
        new_lines = []
        for line in text.split("\n"):
            line = line.strip()
            # Add a dot at the end of the line if it doesn't end with a punctuation mark
            if len(line) > 0 and line[-1] not in self.SENTENCE_END_CHARS:
                line += "."
            new_lines.append(line)
        return " ".join(new_lines)

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Fetch embeddings for every doc in doc_list, in place.

        Each doc is a dict with a "text" key; an "embedding" key is added
        holding a numpy array (or None when the API returned none).
        Returns (doc_list, total_token_usage).

        on_index_progress, if given, is an async callable awaited with
        (last_completed_index, total) as requests complete.
        """
        text_list = [self._normalize_text(doc["text"]) for doc in doc_list]

        token_usage = 0

        async with aiohttp.ClientSession() as session:
            url = self.get_url("embeddings")
            params = {}
            post_data = {
                "input": text_list,
            }

            if self.api_type == "azure":
                params["api-version"] = "2023-05-15"
            else:
                post_data["model"] = "text-embedding-ada-002"

            if self.api_type == "azure":
                # Azure api does not support batch: one request per text,
                # each retried up to max_retry_num times.
                max_retry_num = 3
                for index, text in enumerate(text_list):
                    retry_num = 0
                    while retry_num < max_retry_num:
                        try:
                            async with session.post(url,
                                    headers=self.build_header(),
                                    params=params,
                                    json={"input": text},
                                    timeout=30,
                                    proxy=self.request_proxy) as resp:

                                data = await resp.json()

                                one_data = data["data"]
                                if len(one_data) > 0:
                                    embedding = one_data[0]["embedding"]
                                    if embedding is not None:
                                        embedding = np.array(embedding)
                                    doc_list[index]["embedding"] = embedding

                            token_usage += int(data["usage"]["total_tokens"])

                            if on_index_progress is not None:
                                await on_index_progress(index, len(text_list))

                            break
                        except Exception as e:
                            retry_num += 1

                            # Out of retries: propagate the last error.
                            if retry_num >= max_retry_num:
                                raise e

                            print("Error: %s" % e)
                            print("Retrying...")
                            await asyncio.sleep(0.5)
            else:
                async with session.post(url,
                                        headers=self.build_header(),
                                        params=params,
                                        json=post_data,
                                        timeout=30,
                                        proxy=self.request_proxy) as resp:

                    data = await resp.json()

                    if "error" in data:
                        raise Exception(data["error"])

                    for one_data in data["data"]:
                        embedding = one_data["embedding"]
                        index = one_data["index"]

                        # Guard against indexes outside the request batch.
                        if index < len(doc_list):
                            if embedding is not None:
                                embedding = np.array(embedding)
                            doc_list[index]["embedding"] = embedding

                    token_usage = int(data["usage"]["total_tokens"])

                # Report completion of the whole batch once.  BUGFIX: the
                # original referenced the loop variable here, which was a
                # NameError when the result list was empty, and duplicated
                # the final per-item callback on the Azure path.
                if on_index_progress is not None and len(text_list) > 0:
                    await on_index_progress(len(text_list) - 1, len(text_list))

        return (doc_list, token_usage)

    async def make_message_list(self, question: str, system_prompt: str,
                                conversation: Optional[list[ChatCompleteMessageLog]] = None) -> list[ChatCompleteMessageLog]:
        """Build the message list sent to the chat completion API.

        A "summary" entry in the conversation is appended to the system
        prompt; only "user" and "assistant" entries are forwarded.  The
        (possibly extended) prompt is inserted first and the new question
        appended last.

        NOTE(review): the prompt is inserted with role "assistant"; the
        usual OpenAI convention is "system" — confirm this is intentional
        before changing it.
        """
        # BUGFIX: avoid the shared mutable-default-argument pitfall.
        conversation = conversation if conversation is not None else []

        summaryContent = None

        messageList: list[ChatCompleteMessageLog] = []
        for message in conversation:
            if message["role"] == "summary":
                summaryContent = message["content"]
            elif message["role"] in ("user", "assistant"):
                messageList.append(message)

        if summaryContent:
            system_prompt += "\n\n" + summaryContent

        messageList.insert(0, ChatCompleteMessageLog(role="assistant", content=system_prompt))
        messageList.append(ChatCompleteMessageLog(role="user", content=question))

        return messageList

    async def chat_complete(self, question: str, system_prompt: str,
                            conversation: Optional[list[ChatCompleteMessageLog]] = None, user=None):
        """Run a non-streaming chat completion.

        Returns a ChatCompleteResponse on success, or None when the API
        response contains no choices.
        """
        messageList = await self.make_message_list(question, system_prompt, conversation)

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
        }

        if self.api_type == "azure":
            params["api-version"] = "2023-05-15"
        else:
            post_data["model"] = "gpt-3.5-turbo"

        # Drop unset optional fields (e.g. "user") from the payload.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        async with aiohttp.ClientSession() as session:
            async with session.post(url,
                                    headers=self.build_header(),
                                    params=params,
                                    json=post_data,
                                    timeout=30,
                                    proxy=self.request_proxy) as resp:

                data = await resp.json()

                if "choices" in data and len(data["choices"]) > 0:
                    choice = data["choices"][0]

                    message = choice["message"]["content"]
                    finish_reason = choice["finish_reason"]

                    prompt_tokens = int(data["usage"]["prompt_tokens"])
                    message_tokens = int(data["usage"]["completion_tokens"])
                    total_tokens = int(data["usage"]["total_tokens"])

                    return ChatCompleteResponse(message=message,
                                                prompt_tokens=prompt_tokens,
                                                message_tokens=message_tokens,
                                                total_tokens=total_tokens,
                                                finish_reason=finish_reason)

        return None

    async def chat_complete_stream(self, question: str, system_prompt: str,
                                   conversation: Optional[list[ChatCompleteMessageLog]] = None,
                                   on_message=None, user=None):
        """Run a streaming (SSE) chat completion.

        on_message, if given, is awaited with each non-empty content delta
        as it arrives.  Token counts are estimated locally with TikToken
        since the streaming API reports no usage.  Returns a
        ChatCompleteResponse with the concatenated message.
        """
        tiktoken = await TikTokenService.create()

        messageList = await self.make_message_list(question, system_prompt, conversation)

        # Streaming responses carry no usage block; estimate the prompt size
        # locally from the outgoing messages.
        prompt_tokens = 0
        for message in messageList:
            prompt_tokens += await tiktoken.get_tokens(message["content"])

        url = self.get_url("chat/completions")

        params = {}
        post_data = {
            "messages": messageList,
            "user": user,
            "stream": True,
            "n": 1,
            "max_tokens": 768,
            "stop": None,
            "temperature": 1,
            "top_p": 0.95
        }

        if self.api_type == "azure":
            params["api-version"] = "2023-05-15"
        else:
            post_data["model"] = "gpt-3.5-turbo"

        # Drop unset optional fields (e.g. "user", "stop") from the payload.
        post_data = {k: v for k, v in post_data.items() if v is not None}

        res_message: list[str] = []
        finish_reason = None
        # Tracks whether real content has started so only LEADING
        # empty/newline-only deltas are skipped.  BUGFIX: this flag was
        # previously reset inside the event loop, which silently dropped
        # every whitespace-only delta, not just the leading ones.
        content_started = False

        # Expected event payloads, one JSON chunk per SSE event:
        #   {"id":"...","object":"chat.completion.chunk",...,"choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
        #   {"id":"...","object":"chat.completion.chunk",...,"choices":[{"delta":{"content":"Thank"},"index":0,"finish_reason":null}]}
        #   {"id":"...","object":"chat.completion.chunk",...,"choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}
        #   [DONE]
        async with sse_client.EventSource(
            url,
            option={
                "method": "POST"
            },
            headers=self.build_header(),
            params=params,
            json=post_data,
            proxy=self.request_proxy
        ) as session:
            async for event in session:
                event_data = event.data.strip()

                if event_data == "[DONE]":
                    break
                # startswith/endswith also guards against an empty event,
                # which would have raised IndexError with event_data[0].
                elif event_data.startswith("{") and event_data.endswith("}"):
                    data = json.loads(event_data)
                    if "choices" in data and len(data["choices"]) > 0:
                        choice = data["choices"][0]

                        if choice["finish_reason"] is not None:
                            finish_reason = choice["finish_reason"]

                        delta_content = choice["delta"]
                        if "content" in delta_content:
                            delta_message: str = delta_content["content"]

                            # Skip empty lines before content
                            if not content_started:
                                if delta_message.replace("\n", "") == "":
                                    continue
                                else:
                                    content_started = True

                            res_message.append(delta_message)

                            if on_message is not None:
                                await on_message(delta_message)

                    if finish_reason is not None:
                        break

        res_message_str = "".join(res_message)
        message_tokens = await tiktoken.get_tokens(res_message_str)
        total_tokens = prompt_tokens + message_tokens

        return ChatCompleteResponse(message=res_message_str,
                                    prompt_tokens=prompt_tokens,
                                    message_tokens=message_tokens,
                                    total_tokens=total_tokens,
                                    finish_reason=finish_reason)