from __future__ import annotations import time from config import Config import asyncio import random import threading from typing import Optional, TypedDict import torch from transformers import pipeline from local import loop from service.tiktoken import TikTokenService BERT_EMBEDDING_QUEUE_TIMEOUT = 1 class BERTEmbeddingQueueTaskInfo(TypedDict): task_id: int text: str embedding: torch.Tensor class BERTEmbeddingQueue: def init(self): self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese") self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {} self.task_list: list[BERTEmbeddingQueueTaskInfo] = [] self.lock = threading.Lock() self.thread: Optional[threading.Thread] = None self.running = False async def get_embeddings(self, text: str): text = "[CLS]" + text + "[SEP]" task_id = random.randint(0, 1000000000) with self.lock: while task_id in self.task_map: task_id = random.randint(0, 1000000000) task_info = { "task_id": task_id, "text": text, "embedding": None } self.task_map[task_id] = task_info self.task_list.append(task_info) self.start_queue() while True: task_info = self.pop_task(task_id) if task_info is not None: return task_info["embedding"] await asyncio.sleep(0.01) def pop_task(self, task_id): with self.lock: if task_id in self.task_map: task_info = self.task_map[task_id] if task_info["embedding"] is not None: del self.task_map[task_id] return task_info return None def run(self): running = True last_task_time = None while running and self.running: current_time = time.time() task = None with self.lock: if len(self.task_list) > 0: task = self.task_list.pop(0) if task is not None: embeddings = self.embedding_model(task["text"]) with self.lock: task["embedding"] = embeddings[0][1] last_task_time = time.time() elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT: self.thread = None self.running = False running = False else: time.sleep(0.01) def start_queue(self): if not self.running: self.running = True self.thread = threading.Thread(target=self.run) self.thread.start() bert_embedding_queue = BERTEmbeddingQueue() bert_embedding_queue.init() class BERTEmbeddingService: instance = None @staticmethod async def create() -> BERTEmbeddingService: if BERTEmbeddingService.instance is None: BERTEmbeddingService.instance = BERTEmbeddingService() await BERTEmbeddingService.instance.init() return BERTEmbeddingService.instance async def init(self): self.tiktoken = await TikTokenService.create() self.embedding_queue = BERTEmbeddingQueue() await loop.run_in_executor(None, self.embedding_queue.init) async def get_embeddings(self, docs, on_progress=None): if len(docs) == 0: return ([], 0) if on_progress is not None: await on_progress(0, len(docs)) embeddings = [] token_usage = 0 for doc in docs: if "text" in doc: tokens = await self.tiktoken.get_tokens(doc["text"]) token_usage += tokens embeddings.append({ "id": doc["id"], "text": doc["text"], "embedding": self.model.encode(doc["text"]), "tokens": tokens }) else: embeddings.append({ "id": doc["id"], "text": doc["text"], "embedding": None, "tokens": 0 }) if on_progress is not None: await on_progress(1, len(docs)) return (embeddings, token_usage)