You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.6 KiB
Python

from __future__ import annotations
from copy import deepcopy
import time
from libs.config import Config
import asyncio
import random
import threading
from typing import Callable, Optional, TypedDict
import ollama
from utils.local import loop
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
# Seconds the embedding worker thread may stay idle after finishing a task
# before it shuts itself down (compared against time.time() in the worker loop).
LOCAL_EMBEDDING_QUEUE_TIMEOUT = 1
class OllamaEmbeddingQueueTaskInfo(TypedDict):
    """One queued embedding request together with its (eventual) result."""

    # Random identifier under which the request is tracked in the task map.
    task_id: int
    # Text handed to the embedding model as the prompt.
    text: str
    # None while the request is pending; set once the embedding is computed.
    embedding: Optional[list]
class OllamaEmbeddingQueue:
    """Funnel embedding requests from coroutines through one worker thread.

    Coroutines submit text with :meth:`get_embeddings` and poll for the
    result; a lazily started background thread drains the queue one task at a
    time by calling ``ollama.embeddings``. The worker exits after
    ``LOCAL_EMBEDDING_QUEUE_TIMEOUT`` seconds without work and is started
    again by the next submission.

    All shared state (``task_map``, ``task_list``, ``task_errors``,
    ``running``, ``thread``) is read and written while holding ``lock``.
    """

    def __init__(self, model: str) -> None:
        """Create an idle queue for the Ollama model called ``model``."""
        self.model_name = model
        # Submitted tasks that have not been claimed by their caller yet.
        self.task_map: dict[int, OllamaEmbeddingQueueTaskInfo] = {}
        # FIFO of tasks the worker has not picked up yet.
        self.task_list: list[OllamaEmbeddingQueueTaskInfo] = []
        # Exceptions raised by the backend call, keyed by task id, waiting to
        # be re-raised in the coroutine that submitted the task.
        self.task_errors: dict[int, Exception] = {}
        self.lock = threading.Lock()
        self.thread: Optional[threading.Thread] = None
        self.running = False

    def post_init(self):
        """Hook for blocking start-up work; currently does nothing."""
        pass

    async def get_embeddings(self, text: str):
        """Queue ``text`` for embedding and wait for the outcome.

        Returns:
            The ``embedding`` value of the ``ollama.embeddings`` response.

        Raises:
            Exception: whatever the backend call raised for this task.
        """
        with self.lock:
            task_id = random.randint(0, 1000000000)
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)
            task_info = {
                "task_id": task_id,
                "text": text,
                "embedding": None,
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)
        self.start_queue()
        try:
            while True:
                finished = self.pop_task(task_id)
                if finished is not None:
                    return finished["embedding"]
                error = self._pop_error(task_id)
                if error is not None:
                    raise error
                await asyncio.sleep(0.01)
        finally:
            # Also reached on cancellation: make sure nothing stays behind.
            self._discard(task_id)

    def pop_task(self, task_id):
        """Remove and return the task if its embedding is ready, else None."""
        with self.lock:
            task_info = self.task_map.get(task_id)
            if task_info is None or task_info["embedding"] is None:
                return None
            del self.task_map[task_id]
            return task_info

    def _pop_error(self, task_id):
        """Remove and return the failure recorded for ``task_id``, if any."""
        with self.lock:
            error = self.task_errors.pop(task_id, None)
            if error is not None:
                self.task_map.pop(task_id, None)
            return error

    def _record_failure(self, task, error):
        """Remember ``error`` for ``task`` unless its caller already gave up."""
        with self.lock:
            task_id = task["task_id"]
            if self.task_map.get(task_id) is task:
                self.task_errors[task_id] = error

    def _discard(self, task_id):
        """Forget every trace of ``task_id`` (no-op for claimed tasks)."""
        with self.lock:
            self.task_errors.pop(task_id, None)
            abandoned = self.task_map.pop(task_id, None)
            if abandoned is not None:
                self.task_list[:] = [
                    queued for queued in self.task_list
                    if queued is not abandoned
                ]

    def run(self):
        """Worker loop: embed queued tasks until idle for too long.

        The loop also ends as soon as ``running`` is set to False.
        """
        idle_since = time.monotonic()
        while self.running:
            task = None
            with self.lock:
                if self.task_list:
                    task = self.task_list.pop(0)
                elif time.monotonic() > idle_since + LOCAL_EMBEDDING_QUEUE_TIMEOUT:
                    # Deciding to stop while holding the lock guarantees that
                    # a task appended afterwards sees running == False and
                    # therefore starts a fresh worker.
                    self.thread = None
                    self.running = False
                    return
            if task is None:
                time.sleep(0.01)
                continue
            try:
                res = ollama.embeddings(model=self.model_name, prompt=task["text"])
                embedding = res.embedding
            except Exception as error:
                # Hand the failure to the waiting coroutine instead of letting
                # the thread die and the caller poll forever.
                self._record_failure(task, error)
            else:
                with self.lock:
                    task["embedding"] = embedding
            idle_since = time.monotonic()

    def start_queue(self):
        """Start the worker thread unless one is already marked as running."""
        with self.lock:
            if self.running:
                return
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()
class TextEmbeddingService:
    """Singleton that embeds documents through Ollama or the OpenAI API.

    The backend is chosen by the ``embedding.type`` configuration value
    (``"ollama"`` by default, or ``"openai"``). Obtain the shared instance
    with :meth:`create`.
    """

    instance = None

    # Characters after which no full stop is appended when lines are joined.
    # Despite its regex-like look this text is used as a plain set of
    # characters, never compiled as a pattern.
    _LINE_END_CHARS = frozenset(
        r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
    )

    def __init__(self):
        """Create an unconfigured service; :meth:`init` fills in the fields."""
        self.tiktoken: TikTokenService = None
        self.ollama_queue: OllamaEmbeddingQueue = None
        self.openai_api: OpenAIApi = None
        # Defined up front so the attribute exists even before init() runs.
        self.embedding_type: Optional[str] = None

    @staticmethod
    async def create() -> "TextEmbeddingService":
        """Return the shared instance, building and initialising it if needed."""
        if TextEmbeddingService.instance is None:
            candidate = TextEmbeddingService()
            await candidate.init()
            # Publish only after init() succeeded, so neither a failed init
            # nor a concurrent caller can observe a half-built instance.
            if TextEmbeddingService.instance is None:
                TextEmbeddingService.instance = candidate
        return TextEmbeddingService.instance

    async def init(self):
        """Load the tokenizer service and set up the configured backend."""
        self.tiktoken = await TikTokenService.create()
        self.embedding_type = Config.get("embedding.type", "ollama")
        if self.embedding_type == "ollama":
            embedding_model = Config.get(
                "embedding.embedding_model", "shaw/dmeta-embedding-zh"
            )
            self.ollama_queue = OllamaEmbeddingQueue(model=embedding_model)
            await loop.run_in_executor(None, self.ollama_queue.post_init)
        elif self.embedding_type == "openai":
            api_id = Config.get("embedding.api_id")
            self.openai_api: OpenAIApi = await OpenAIApi.create(api_id)

    async def get_ollama_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Embed every document of ``doc_list`` in place through the Ollama queue.

        Args:
            doc_list: Dicts with a ``"text"`` key; each gains an ``"embedding"`` key.
            on_index_progress: Optional callback invoked after each document
                with ``(index, len(doc_list))``. It may be a plain function or
                return an awaitable, which is then awaited.

        Returns:
            ``(doc_list, total_token_usage)`` where the usage is the sum of
            ``self.tiktoken.get_tokens(text)`` over all documents.
        """
        import inspect

        total_token_usage = 0
        doc_count = len(doc_list)
        for index, doc in enumerate(doc_list):
            text = doc["text"]
            doc["embedding"] = await self.ollama_queue.get_embeddings(text)
            total_token_usage += await self.tiktoken.get_tokens(text)
            if on_index_progress is not None:
                outcome = on_index_progress(index, doc_count)
                # Awaiting unconditionally would fail for plain callbacks.
                if inspect.isawaitable(outcome):
                    await outcome
        return (doc_list, total_token_usage)

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Normalise the documents' text and embed them with the active backend.

        ``doc_list`` itself is left untouched; a deep copy is normalised and
        passed on.

        Returns:
            The result of the backend call: ``(docs, total_token_usage)`` for
            Ollama, or whatever ``openai_api.get_embeddings`` returns for
            OpenAI. ``None`` when ``embedding_type`` is neither of the two.
        """
        res_doc_list = deepcopy(doc_list)
        for doc in res_doc_list:
            doc["text"] = self._normalize_text(doc["text"])
        if self.embedding_type == "ollama":
            return await self.get_ollama_embeddings(res_doc_list, on_index_progress)
        if self.embedding_type == "openai":
            return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
        # Unknown backend: keep the historical behaviour of returning None.
        return None

    @classmethod
    def _normalize_text(cls, text: str) -> str:
        """Collapse multi-line text into one line of sentence-like pieces."""
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        if "\n" not in text:
            return text
        pieces = []
        for line in text.split("\n"):
            line = line.strip()
            # Add a dot at the end of the line if it doesn't end with a punctuation mark
            if line and line[-1] not in cls._LINE_END_CHARS:
                line += "."
            pieces.append(line)
        return " ".join(pieces)