from __future__ import annotations from copy import deepcopy import time from libs.config import Config import asyncio import random import threading from typing import Callable, Optional, TypedDict import torch from text2vec import SentenceModel from utils.local import loop from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService BERT_EMBEDDING_QUEUE_TIMEOUT = 1 class Text2VecEmbeddingQueueTaskInfo(TypedDict): task_id: int text: str embedding: torch.Tensor class Text2VecEmbeddingQueue: def __init__(self, model: str) -> None: self.model_name = model self.embedding_model: SentenceModel = None self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {} self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = [] self.lock = threading.Lock() self.thread: Optional[threading.Thread] = None self.running = False def post_init(self): self.embedding_model = SentenceModel(self.model_name) async def get_embeddings(self, text: str): task_id = random.randint(0, 1000000000) with self.lock: while task_id in self.task_map: task_id = random.randint(0, 1000000000) task_info = { "task_id": task_id, "text": text, "embedding": None } self.task_map[task_id] = task_info self.task_list.append(task_info) self.start_queue() while True: task_info = self.pop_task(task_id) if task_info is not None: return task_info["embedding"] await asyncio.sleep(0.01) def pop_task(self, task_id): with self.lock: if task_id in self.task_map: task_info = self.task_map[task_id] if task_info["embedding"] is not None: del self.task_map[task_id] return task_info return None def run(self): running = True last_task_time = None while running and self.running: current_time = time.time() task = None with self.lock: if len(self.task_list) > 0: task = self.task_list.pop(0) if task is not None: embeddings = self.embedding_model.encode([task["text"]]) with self.lock: task["embedding"] = embeddings[0] last_task_time = time.time() elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT: self.thread = None self.running = False running = False else: time.sleep(0.01) def start_queue(self): if not self.running: self.running = True self.thread = threading.Thread(target=self.run) self.thread.start() class TextEmbeddingService: instance = None def __init__(self): self.tiktoken: TikTokenService = None self.text2vec_queue: Text2VecEmbeddingQueue = None self.openai_api: OpenAIApi = None @staticmethod async def create() -> TextEmbeddingService: if TextEmbeddingService.instance is None: TextEmbeddingService.instance = TextEmbeddingService() await TextEmbeddingService.instance.init() return TextEmbeddingService.instance async def init(self): self.tiktoken = await TikTokenService.create() self.embedding_type = Config.get("embedding.type", "text2vec") if self.embedding_type == "text2vec": embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese") self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model) elif self.embedding_type == "openai": api_id = Config.get("embedding.api_id") self.openai_api: OpenAIApi = await OpenAIApi.create(api_id) await loop.run_in_executor(None, self.text2vec_queue.post_init) async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): total_token_usage = 0 for index, doc in enumerate(doc_list): text = doc["text"] embedding = await self.text2vec_queue.get_embeddings(text) doc["embedding"] = embedding total_token_usage += await self.tiktoken.get_tokens(text) if on_index_progress is not None: await on_index_progress(index, len(doc_list)) return (doc_list, total_token_usage) async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): res_doc_list = deepcopy(doc_list) regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]" for doc in res_doc_list: text: str = doc["text"] text = text.replace("\r\n", "\n").replace("\r", "\n") if "\n" in text: lines = text.split("\n") new_lines = [] for line in lines: line = line.strip() # Add a dot at the end of the line if it doesn't end with a punctuation mark if len(line) > 0 and regex.find(line[-1]) == -1: line += "." new_lines.append(line) text = " ".join(new_lines) doc["text"] = text if self.embedding_type == "text2vec": return await self.get_text2vec_embeddings(res_doc_list, on_index_progress) elif self.embedding_type == "openai": return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)