from __future__ import annotations

import asyncio
import random
import threading
import time
from typing import Optional, TypedDict

import torch
from transformers import pipeline

from config import Config
from local import loop
from service.tiktoken import TikTokenService

# Seconds the worker thread stays alive with no pending tasks before it exits.
BERT_EMBEDDING_QUEUE_TIMEOUT = 1


class BERTEmbeddingQueueTaskInfo(TypedDict):
    """Bookkeeping entry for a single queued embedding request."""

    task_id: int
    text: str
    embedding: Optional[torch.Tensor]


class BERTEmbeddingQueue:
    """Serializes BERT feature-extraction requests onto a single worker thread."""

    def init(self):
        # Explicit init (rather than __init__) so the model load can be deferred
        # or pushed onto an executor thread (see BERTEmbeddingService.init).
        self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
        self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
        self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        # Wrap the text with BERT's special tokens and register a task under a
        # collision-free random id.
        text = "[CLS]" + text + "[SEP]"
        task_id = random.randint(0, 1000000000)
        with self.lock:
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task_info: BERTEmbeddingQueueTaskInfo = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)

        # Make sure the worker thread is running, then poll until it has
        # produced the embedding for this task.
        self.start_queue()
        while True:
            task_info = self.pop_task(task_id)
            if task_info is not None:
                return task_info["embedding"]

            await asyncio.sleep(0.01)

    def pop_task(self, task_id):
        # Return the finished task (and drop it from the map) once the worker
        # has filled in its embedding; otherwise return None so the caller
        # keeps polling.
        with self.lock:
            if task_id in self.task_map:
                task_info = self.task_map[task_id]
                if task_info["embedding"] is not None:
                    del self.task_map[task_id]
                    return task_info

            return None

    def run(self):
        running = True
        last_task_time = None
        while running and self.running:
            current_time = time.time()
            task = None
            with self.lock:
                if len(self.task_list) > 0:
                    task = self.task_list.pop(0)

            if task is not None:
                # Run the model outside the lock; only the result assignment is
                # synchronized.
                embeddings = self.embedding_model(task["text"])

                with self.lock:
                    # Pipeline output is [batch][token][hidden]; keep the vector
                    # at token position 1.
                    task["embedding"] = embeddings[0][1]

                last_task_time = time.time()
            elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
                # Idle for longer than the timeout: shut the worker down.
                self.thread = None
                self.running = False
                running = False
            else:
                time.sleep(0.01)

    def start_queue(self):
        # Lazily start (or restart) the worker thread.
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()


# Module-level queue instance; note that init() loads the BERT pipeline eagerly
# at import time.
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
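
# Usage sketch (illustrative only): from asynchronous code, an embedding can be
# requested straight from the module-level queue, e.g.
#
#     vector = await bert_embedding_queue.get_embeddings("你好，世界")
#
# The call enqueues the text, starts the worker thread if it is not running,
# and polls until the worker has filled in the result.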


class BERTEmbeddingService:
    instance: Optional[BERTEmbeddingService] = None

    @staticmethod
    async def create() -> BERTEmbeddingService:
        # Lazily create and initialize the shared singleton instance.
        if BERTEmbeddingService.instance is None:
            BERTEmbeddingService.instance = BERTEmbeddingService()
            await BERTEmbeddingService.instance.init()
        return BERTEmbeddingService.instance

    async def init(self):
        self.tiktoken = await TikTokenService.create()
        self.embedding_queue = BERTEmbeddingQueue()
        # Load the BERT pipeline on an executor thread so init() does not block
        # the event loop.
        await loop.run_in_executor(None, self.embedding_queue.init)

    async def get_embeddings(self, docs, on_progress=None):
        if len(docs) == 0:
            return ([], 0)

        if on_progress is not None:
            await on_progress(0, len(docs))

        embeddings = []
        token_usage = 0

        for doc in docs:
            if "text" in doc:
                tokens = await self.tiktoken.get_tokens(doc["text"])
                token_usage += tokens
                embeddings.append({
                    "id": doc["id"],
                    "text": doc["text"],
                    # Embed through the queue so model calls stay on its worker
                    # thread.
                    "embedding": await self.embedding_queue.get_embeddings(doc["text"]),
                    "tokens": tokens
                })
            else:
                # No text for this document: record a placeholder entry.
                embeddings.append({
                    "id": doc["id"],
                    "text": None,
                    "embedding": None,
                    "tokens": 0
                })

        if on_progress is not None:
            await on_progress(1, len(docs))

        return (embeddings, token_usage)
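
# Usage sketch (illustrative only; assumes the config/local/service modules this
# file imports are available and that the call runs inside an asyncio event loop):
#
#     service = await BERTEmbeddingService.create()
#     embeddings, token_usage = await service.get_embeddings(
#         [{"id": 1, "text": "今天天气不错"}]
#     )
#     # embeddings -> list of {"id", "text", "embedding", "tokens"} dicts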