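"""Text embedding service.

Turns documents into embedding vectors, either with a local text2vec
SentenceModel driven by a background worker thread (Text2VecEmbeddingQueue)
or through the OpenAI embeddings API, depending on the "embedding.type"
configuration value.
"""
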
from __future__ import annotations

import asyncio
import random
import threading
import time
from copy import deepcopy
from typing import Awaitable, Callable, Optional, TypedDict

import torch
from text2vec import SentenceModel

from lib.config import Config
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
from utils.local import loop

# Seconds of inactivity after which the embedding worker thread shuts itself down
BERT_EMBEDDING_QUEUE_TIMEOUT = 1


class Text2VecEmbeddingQueueTaskInfo(TypedDict):
    task_id: int
    text: str
    # None until the worker thread has produced the embedding
    embedding: Optional[torch.Tensor]


class Text2VecEmbeddingQueue:
    def __init__(self, model: str) -> None:
        self.model_name = model

        # Loaded lazily by init() so the slow model load can run off the event loop
        self.embedding_model: Optional[SentenceModel] = None

        self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
        self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    def init(self):
        # Idempotent: load the SentenceModel once
        if self.embedding_model is None:
            self.embedding_model = SentenceModel(self.model_name)

    async def get_embeddings(self, text: str):
        # Register a task under a unique id, then poll until the worker fills it in
        task_id = random.randint(0, 1000000000)
        with self.lock:
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task_info: Text2VecEmbeddingQueueTaskInfo = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)

        self.start_queue()
        while True:
            task_info = self.pop_task(task_id)
            if task_info is not None:
                return task_info["embedding"]

            await asyncio.sleep(0.01)

    def pop_task(self, task_id):
        with self.lock:
            if task_id in self.task_map:
                task_info = self.task_map[task_id]
                if task_info["embedding"] is not None:
                    del self.task_map[task_id]
                    return task_info

            return None

    def run(self):
        # Worker loop: encode queued texts one at a time and exit after
        # BERT_EMBEDDING_QUEUE_TIMEOUT seconds without new tasks.
        self.init()

        running = True
        last_task_time = None
        while running and self.running:
            current_time = time.time()
            task = None
            with self.lock:
                if len(self.task_list) > 0:
                    task = self.task_list.pop(0)

            if task is not None:
                embeddings = self.embedding_model.encode([task["text"]])

                with self.lock:
                    task["embedding"] = embeddings[0]

                last_task_time = time.time()
            elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
                self.thread = None
                self.running = False
                running = False
            else:
                time.sleep(0.01)

    def start_queue(self):
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()
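

# Usage sketch for the queue on its own (illustrative only; assumes the
# text2vec weights for the given model name can be loaded or downloaded):
#
#     queue = Text2VecEmbeddingQueue(model="shibing624/text2vec-base-chinese")
#     await loop.run_in_executor(None, queue.init)  # load the model off the event loop
#     vector = await queue.get_embeddings("some text")
#
# get_embeddings() enqueues a task, (re)starts the single worker thread, and
# polls every 10 ms until the worker has written the embedding back.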


class TextEmbeddingService:
    # Process-wide singleton, created via TextEmbeddingService.create()
    instance: Optional[TextEmbeddingService] = None

    @staticmethod
    async def create() -> TextEmbeddingService:
        if TextEmbeddingService.instance is None:
            TextEmbeddingService.instance = TextEmbeddingService()
            await TextEmbeddingService.instance.init()
        return TextEmbeddingService.instance

    async def init(self):
        self.tiktoken = await TikTokenService.create()

        self.embedding_type = Config.get("embedding.type", "text2vec")

        if self.embedding_type == "text2vec":
            embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
            self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model)
            # Load the text2vec model in a thread pool so the event loop is not blocked
            await loop.run_in_executor(None, self.embedding_queue.init)
        elif self.embedding_type == "openai":
            self.openai_api: OpenAIApi = await OpenAIApi.create()

    async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], Awaitable[None]]] = None):
        for index, doc in enumerate(doc_list):
            text = doc["text"]
            embedding = await self.embedding_queue.get_embeddings(text)
            doc["embedding"] = embedding

            if on_index_progress is not None:
                await on_index_progress(index, len(doc_list))

        return doc_list

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], Awaitable[None]]] = None):
        # Work on a copy so the caller's doc_list is left untouched
        res_doc_list = deepcopy(doc_list)

        # Characters treated as end-of-line punctuation (ASCII and CJK)
        punctuation_chars = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
        for doc in res_doc_list:
            text: str = doc["text"]
            text = text.replace("\r\n", "\n").replace("\r", "\n")
            if "\n" in text:
                lines = text.split("\n")
                new_lines = []
                for line in lines:
                    line = line.strip()
                    # Add a dot at the end of the line if it doesn't end with a punctuation mark
                    if len(line) > 0 and line[-1] not in punctuation_chars:
                        line += "."
                    new_lines.append(line)
                text = " ".join(new_lines)
            doc["text"] = text

        if self.embedding_type == "text2vec":
            return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
        elif self.embedding_type == "openai":
            return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
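

# Example wiring (hypothetical; assumes Config resolves "embedding.type" to
# "text2vec" and that the caller already runs inside the project's event loop):
#
#     service = await TextEmbeddingService.create()
#     docs = await service.get_embeddings([{"text": "第一行\n第二行"}])
#     # each returned doc carries an "embedding" alongside its normalized text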