from __future__ import annotations
import time
from config import Config
import asyncio
import random
import threading
from typing import Optional, TypedDict
import torch
from transformers import pipeline
from local import loop

from service.tiktoken import TikTokenService

# Seconds of idleness after which the BERTEmbeddingQueue worker thread exits;
# the next enqueued request starts a fresh worker via start_queue().
BERT_EMBEDDING_QUEUE_TIMEOUT = 1

class BERTEmbeddingQueueTaskInfo(TypedDict):
    """One queued embedding request, shared by the waiting caller and the worker."""

    # Random id; the key of this task in BERTEmbeddingQueue.task_map.
    task_id: int
    # Text passed to the feature-extraction model.
    text: str
    # None until the worker stores the model output for this task.
    # NOTE(review): the pipeline output is stored as-is; confirm whether it is a
    # tensor or a nested list of floats and tighten this annotation accordingly.
    embedding: Optional[torch.Tensor]

class BERTEmbeddingQueue:
    """Serialise BERT feature-extraction calls onto one background thread.

    Coroutines enqueue text with :meth:`get_embeddings` and poll for the result.
    A worker thread, started on demand by :meth:`start_queue`, runs the model
    and exits after ``BERT_EMBEDDING_QUEUE_TIMEOUT`` seconds without work.
    :meth:`init` must be called before use.
    """

    def init(self, embedding_model=None):
        """Set up the model and empty queue state.

        Args:
            embedding_model: optional callable with the same call/return shape as
                a ``transformers`` feature-extraction pipeline. When ``None`` (the
                default) the ``bert-base-chinese`` pipeline is loaded.
        """
        if embedding_model is None:
            embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
        self.embedding_model = embedding_model
        # task_id -> task, for tasks whose caller has not collected the result yet.
        self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
        # FIFO of tasks not yet taken by the worker.
        self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
        # task_id -> exception raised by the model, re-raised in the waiting caller.
        self.task_errors: dict[int, BaseException] = {}
        # Protects the task containers and the running/thread start-stop hand-off.
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Queue *text* for embedding and wait for the worker's result.

        Returns:
            ``output[0][1]`` of the model output for the wrapped text.

        Raises:
            Exception: whatever the model raised while processing this text.
        """
        # NOTE(review): transformers tokenizers usually add [CLS]/[SEP] themselves,
        # so index 1 (see run) may select the embedding of this literal "[CLS]"
        # rather than of the first character; confirm this is intended.
        text = "[CLS]" + text + "[SEP]"
        with self.lock:
            task_id = random.randint(0, 1000000000)
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task = {
                "task_id": task_id,
                "text": text,
                "embedding": None,
            }
            self.task_map[task_id] = task
            self.task_list.append(task)

        self.start_queue()
        try:
            while True:
                done = self.pop_task(task_id)
                if done is not None:
                    return done["embedding"]

                await asyncio.sleep(0.01)
        finally:
            # On cancellation (or any early exit) drop the task so it neither
            # lingers in task_map nor gets processed for nobody.
            with self.lock:
                if self.task_map.pop(task_id, None) is not None:
                    try:
                        self.task_list.remove(task)
                    except ValueError:
                        pass  # the worker has already taken it
                self.task_errors.pop(task_id, None)

    def pop_task(self, task_id):
        """Remove and return a finished task, or return ``None`` while it is pending.

        Raises:
            Exception: the error the model raised for this task. The task is
                removed in that case too.
        """
        with self.lock:
            error = self.task_errors.pop(task_id, None)
            if error is not None:
                self.task_map.pop(task_id, None)
                raise error

            task = self.task_map.get(task_id)
            if task is None or task["embedding"] is None:
                return None
            del self.task_map[task_id]
            return task

    def run(self):
        """Worker loop: embed queued tasks in FIFO order until idle for too long."""
        # Start the idle clock immediately so that a worker which never receives
        # a task still shuts down instead of polling forever.
        last_task_time = time.time()
        while self.running:
            with self.lock:
                task = self.task_list.pop(0) if self.task_list else None
                if task is None and time.time() > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
                    # Decide to exit while still holding the lock. start_queue()
                    # checks `running` under the same lock, so any task enqueued
                    # after this point is guaranteed to get a fresh worker.
                    self.thread = None
                    self.running = False
                    return

            if task is None:
                time.sleep(0.01)
                continue

            try:
                embedding = self.embedding_model(task["text"])[0][1]
            except Exception as err:
                # Hand the failure to the waiting caller instead of letting it
                # kill this thread, which would leave `running` stuck at True.
                with self.lock:
                    if task["task_id"] in self.task_map:
                        self.task_errors[task["task_id"]] = err
            else:
                with self.lock:
                    task["embedding"] = embedding

            last_task_time = time.time()

    def start_queue(self):
        """Start the worker thread unless one is already running."""
        with self.lock:
            if self.running:
                return
            self.running = True
            self.thread = threading.Thread(target=self.run)
            # Safe under the lock: start() returns once the thread is bootstrapped,
            # before run() tries to take the lock.
            self.thread.start()

def _create_bert_embedding_queue() -> BERTEmbeddingQueue:
    """Build a BERTEmbeddingQueue and initialise it (this loads its model)."""
    queue = BERTEmbeddingQueue()
    queue.init()
    return queue


# Shared queue, created and initialised when this module is imported.
bert_embedding_queue = _create_bert_embedding_queue()

class BERTEmbeddingService:
    """Async facade that embeds documents with BERT and counts their tokens.

    Obtain the shared instance with ``await BERTEmbeddingService.create()``.
    """

    # Process-wide instance returned by create(); None until first use.
    instance = None

    @staticmethod
    async def create() -> BERTEmbeddingService:
        """Return the shared instance, constructing and initialising it on first call."""
        # NOTE(review): ``instance`` is published before ``init()`` completes, so a
        # coroutine calling create() during that await may receive an instance
        # that has no ``tiktoken``/``embedding_queue`` yet. Guard with an
        # asyncio.Lock if concurrent first calls are possible.
        if BERTEmbeddingService.instance is None:
            BERTEmbeddingService.instance = BERTEmbeddingService()
            await BERTEmbeddingService.instance.init()
        return BERTEmbeddingService.instance

    async def init(self):
        """Acquire the token counter and the embedding queue."""
        self.tiktoken = await TikTokenService.create()
        # Reuse the module-level queue, whose BERT pipeline is already loaded at
        # import time. Building a new BERTEmbeddingQueue here would load a second
        # copy of the model.
        self.embedding_queue = bert_embedding_queue

    async def get_embeddings(self, docs, on_progress=None):
        """Embed each document's ``"text"`` and count its tokens.

        Args:
            docs: sequence of dicts with an ``"id"`` key and optionally a
                ``"text"`` key.
            on_progress: optional async callback, awaited as
                ``on_progress(0, total)`` before processing and
                ``on_progress(1, total)`` after each document.

        Returns:
            ``(embeddings, token_usage)``: one dict per input doc, in input order,
            with keys ``id``, ``text``, ``embedding`` and ``tokens``; and the
            summed token count. Docs without ``"text"`` get ``text`` and
            ``embedding`` set to ``None`` and ``tokens`` set to 0.
        """
        if not docs:
            return ([], 0)

        total = len(docs)
        if on_progress is not None:
            await on_progress(0, total)

        embeddings = []
        token_usage = 0

        for doc in docs:
            if "text" in doc:
                text = doc["text"]
                tokens = await self.tiktoken.get_tokens(text)
                token_usage += tokens
                embeddings.append({
                    "id": doc["id"],
                    "text": text,
                    # Runs on the queue's worker thread; awaited without blocking the loop.
                    "embedding": await self.embedding_queue.get_embeddings(text),
                    "tokens": tokens,
                })
            else:
                embeddings.append({
                    "id": doc["id"],
                    # This doc has no text, so there is nothing to copy or embed.
                    "text": None,
                    "embedding": None,
                    "tokens": 0,
                })

            if on_progress is not None:
                # NOTE(review): reports 1 after every document. If callers expect
                # a running count rather than an increment, this should be the
                # 1-based index; confirm against on_progress implementations.
                await on_progress(1, total)

        return (embeddings, token_usage)