完成ChatComplete重构
parent
2a1e5c1589
commit
1355dbbf35
@ -1,56 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import TypeVar
|
|
||||||
import toml
|
|
||||||
|
|
||||||
class Config:
    """Process-wide configuration store backed by a TOML file.

    All state lives in the class attribute ``values``; the class is used
    as a namespace and is never instantiated.
    """

    # Parsed configuration tree (nested dicts, as produced by toml.load).
    values: dict = {}

    @staticmethod
    def load_config(file):
        """Load the TOML file at ``file``, replacing ``Config.values``."""
        with open(file, "r", encoding="utf-8") as f:
            Config.values = toml.load(f)

    @staticmethod
    def get(key: str, default=None, type=None, empty_is_none=False):
        """Look up a dotted ``key`` (e.g. ``"a.b.c"``) in the config tree.

        Returns ``default`` when any path segment is missing.  When
        ``empty_is_none`` is true, an empty-string value is reported as
        ``None``.  When ``type`` is one of bool/int/float/str/list/dict,
        the value is coerced; list/dict fall back to an empty container
        when the stored value has the wrong type.
        """
        value = Config.values
        for part in key.split("."):
            if part not in value:
                return default
            value = value[part]

        if empty_is_none and value == "":
            return None

        if type == bool:
            if isinstance(value, bool):
                return value
            if isinstance(value, (int, float)):
                return value != 0
            # Accept common truthy spellings for string-typed flags.
            return str(value).lower() in ("yes", "true", "1")
        if type == int:
            return int(value)
        if type == float:
            return float(value)
        if type == str:
            return str(value)
        if type == list:
            # BUG FIX: a correctly-typed list previously fell through the
            # elif chain and returned None; return the list itself.
            return value if isinstance(value, list) else []
        if type == dict:
            # BUG FIX: the same fall-through existed for dict values.
            return value if isinstance(value, dict) else {}
        return value

    @staticmethod
    def set(key: str, value):
        """Set a dotted ``key``, creating intermediate dicts as needed."""
        key_path = key.split(".")
        obj = Config.values
        for k in key_path[:-1]:
            if k not in obj:
                obj[k] = {}
            obj = obj[k]
        obj[key_path[-1]] = value
|
|
@ -1,6 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from noawait import NoAwaitPool
|
|
||||||
|
|
||||||
# Global debug flag for the application.
debug = False

# Single shared event loop plus the fire-and-forget pool bound to it.
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)
|
|
@ -1,165 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from asyncio import AbstractEventLoop, Task
|
|
||||||
import asyncio
|
|
||||||
import atexit
|
|
||||||
from functools import wraps
|
|
||||||
import random
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from typing import Callable, Coroutine, Optional, TypedDict
|
|
||||||
|
|
||||||
class TimerInfo(TypedDict):
    """Bookkeeping record for one repeating timer in the task pool."""

    id: int  # random identifier, key into the timer map
    callback: Callable  # invoked each period; may return an awaitable
    interval: float  # seconds between invocations
    next_time: float  # next due time, measured in event-loop time
|
|
||||||
|
|
||||||
class NoAwaitPool:
    """Fire-and-forget task pool bound to one asyncio event loop.

    Coroutines handed to add_task() run without being awaited by the
    caller.  A background GC task reaps finished tasks, a timer task
    drives repeating callbacks registered with add_timer(), and any
    exception is routed through the ``on_error`` handler list.
    """

    def __init__(self, loop: AbstractEventLoop):
        self.task_list: list[Task] = []
        self.timer_map: dict[int, TimerInfo] = {}
        self.loop = loop
        self.running = True

        # Set when a wrapped task finishes so _run_gc knows to scan.
        self.should_refresh_task = False
        # Earliest deadline among all timers, in event-loop time.
        self.next_timer_time: Optional[float] = None

        # Handlers called with the exception; may return an awaitable.
        self.on_error: list[Callable] = []

        self.gc_task = loop.create_task(self._run_gc())
        self.timer_task = loop.create_task(self._run_timer())

        # Best-effort shutdown if the program exits without calling end().
        atexit.register(self.end_task)

    async def _report_error(self, e: Exception):
        """Run every on_error handler for ``e``; print to stderr if none
        succeeds.

        Deduplicates the two byte-identical handler blocks that previously
        lived in _wrap_task and _finish_task.
        """
        handled = False
        for handler in self.on_error:
            try:
                handler_ret = handler(e)
                await handler_ret
                handled = True
            except Exception as handler_err:
                print("Exception on error handler: " + str(handler_err), file=sys.stderr)
                traceback.print_exc()

        if not handled:
            print(e, file=sys.stderr)
            traceback.print_exc()

    async def end(self):
        """Stop the pool: finish or cancel tracked tasks, join helpers."""
        if self.running:
            print("Stopping NoAwait Tasks...")
            self.running = False
            for task in self.task_list:
                await self._finish_task(task)

            await self.gc_task
            await self.timer_task

    def end_task(self):
        """Synchronous atexit hook; drives end() on the owned loop."""
        if self.running and not self.loop.is_closed():
            self.loop.run_until_complete(self.end())

    async def _wrap_task(self, task: Task):
        """Await ``task``, reporting any exception to the handlers.

        NOTE(review): nothing in this file calls _wrap_task, so
        should_refresh_task is only set if a caller outside this view
        wraps tasks manually — confirm against callers.
        """
        try:
            await task
        except Exception as e:
            await self._report_error(e)
        finally:
            self.should_refresh_task = True

    def add_task(self, coroutine: Coroutine):
        """Schedule ``coroutine`` on the loop and track the task."""
        task = self.loop.create_task(coroutine)
        self.task_list.append(task)

    def add_timer(self, callback: Callable, interval: float) -> int:
        """Register ``callback`` to fire every ``interval`` seconds.

        Returns an id usable with remove_timer().
        """
        timer_id = random.randint(0, 1000000000)
        while timer_id in self.timer_map:  # extremely unlikely collision
            timer_id = random.randint(0, 1000000000)

        now = self.loop.time()
        next_time = now + interval
        self.timer_map[timer_id] = {
            "id": timer_id,
            "callback": callback,
            "interval": interval,
            "next_time": next_time
        }

        # Pull the wake-up deadline forward if this timer is due sooner.
        if self.next_timer_time is None or next_time < self.next_timer_time:
            self.next_timer_time = next_time

        return timer_id

    def remove_timer(self, id: int):
        """Unregister the timer with the given id (no-op if unknown)."""
        if id in self.timer_map:
            del self.timer_map[id]

    def wrap(self, f):
        """Decorator: make an async function fire-and-forget."""
        @wraps(f)
        def decorated_function(*args, **kwargs):
            coroutine = f(*args, **kwargs)
            self.add_task(coroutine)

        return decorated_function

    async def _finish_task(self, task: Task):
        """Cancel (if still pending) and await ``task``; report errors."""
        try:
            if not task.done():
                task.cancel()
            await task
        except Exception as e:
            await self._report_error(e)

    async def _run_gc(self):
        """Periodically drop finished tasks from task_list."""
        while self.running:
            if self.should_refresh_task:
                should_remove = []
                for task in self.task_list:
                    if task.done():
                        await self._finish_task(task)
                        should_remove.append(task)
                for task in should_remove:
                    self.task_list.remove(task)

            await asyncio.sleep(0.1)

    async def _run_timer(self):
        """Fire due timers and maintain the next wake-up deadline."""
        while self.running:
            now = self.loop.time()
            if self.next_timer_time is not None and now >= self.next_timer_time:
                self.next_timer_time = None
                for timer in self.timer_map.values():
                    if now >= timer["next_time"]:
                        timer["next_time"] = now + timer["interval"]
                        try:
                            result = timer["callback"]()
                            self.add_task(result)
                        except Exception as e:
                            # Handler awaitables are scheduled, not awaited,
                            # so one slow handler cannot stall the timer
                            # loop — hence no _report_error() here.
                            handled = False
                            for handler in self.on_error:
                                try:
                                    handler_ret = handler(e)
                                    self.add_task(handler_ret)
                                    handled = True
                                except Exception as handler_err:
                                    print("Exception on error handler: " + str(handler_err), file=sys.stderr)
                                    traceback.print_exc()

                            if not handled:
                                print(e, file=sys.stderr)
                                traceback.print_exc()
                    if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
                        self.next_timer_time = timer["next_time"]

            await asyncio.sleep(0.1)
|
|
@ -1,5 +1,3 @@
|
|||||||
transformers
|
text2vec>=1.2.9
|
||||||
--index-url https://download.pytorch.org/whl/cpu
|
--index-url https://download.pytorch.org/whl/cpu
|
||||||
torch
|
torch
|
||||||
torchvision
|
|
||||||
torchaudio
|
|
@ -1,143 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import time
|
|
||||||
from config import Config
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import threading
|
|
||||||
from typing import Optional, TypedDict
|
|
||||||
import torch
|
|
||||||
from transformers import pipeline
|
|
||||||
from local import loop
|
|
||||||
|
|
||||||
from service.tiktoken import TikTokenService
|
|
||||||
|
|
||||||
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
|
|
||||||
|
|
||||||
class BERTEmbeddingQueueTaskInfo(TypedDict):
    """One pending/finished request in the BERT embedding queue."""

    task_id: int  # random identifier, key into the queue's task map
    text: str  # input text, already wrapped in [CLS]/[SEP]
    embedding: torch.Tensor  # None until the worker thread fills it in
|
|
||||||
|
|
||||||
class BERTEmbeddingQueue:
    """Serializes BERT feature-extraction requests onto one worker thread.

    The worker starts lazily on the first request and shuts itself down
    after the queue has been idle for BERT_EMBEDDING_QUEUE_TIMEOUT seconds.
    """

    def init(self):
        # Heavyweight model load is deliberately outside __init__ so the
        # caller decides when (and on which executor) it happens.
        self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
        self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
        self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Queue ``text`` and poll until the worker fills its embedding."""
        text = "[CLS]" + text + "[SEP]"
        task_id = random.randint(0, 1000000000)
        with self.lock:
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            entry = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = entry
            self.task_list.append(entry)

        self.start_queue()
        while True:
            finished = self.pop_task(task_id)
            if finished is not None:
                return finished["embedding"]

            await asyncio.sleep(0.01)

    def pop_task(self, task_id):
        """Remove and return the finished task, or None while pending."""
        with self.lock:
            entry = self.task_map.get(task_id)
            if entry is not None and entry["embedding"] is not None:
                del self.task_map[task_id]
                return entry
        return None

    def run(self):
        """Worker loop: compute embeddings until idle past the timeout."""
        alive = True
        last_task_time = None
        while alive and self.running:
            now = time.time()
            pending = None
            with self.lock:
                if self.task_list:
                    pending = self.task_list.pop(0)

            if pending is not None:
                result = self.embedding_model(pending["text"])
                with self.lock:
                    # Element [0][1] of the pipeline output — presumably
                    # the vector for the token right after [CLS]; confirm
                    # against the pipeline's output layout.
                    pending["embedding"] = result[0][1]

                last_task_time = time.time()
            elif last_task_time is not None and now > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
                # Idle too long: let the thread die; a later request
                # starts a fresh one via start_queue().
                self.thread = None
                self.running = False
                alive = False
            else:
                time.sleep(0.01)

    def start_queue(self):
        """Start the worker thread if it is not already running."""
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()
|
|
||||||
|
|
||||||
# Module-level singleton queue; note init() loads the BERT pipeline
# eagerly, at import time.
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
|
|
||||||
|
|
||||||
class BERTEmbeddingService:
    """Singleton service that embeds documents via the BERT queue."""

    # Lazily-created singleton instance.
    instance = None

    @staticmethod
    async def create() -> BERTEmbeddingService:
        """Return the shared instance, initializing it on first use."""
        if BERTEmbeddingService.instance is None:
            BERTEmbeddingService.instance = BERTEmbeddingService()
            await BERTEmbeddingService.instance.init()
        return BERTEmbeddingService.instance

    async def init(self):
        """Create collaborators; the model load runs on the default executor."""
        self.tiktoken = await TikTokenService.create()
        self.embedding_queue = BERTEmbeddingQueue()
        await loop.run_in_executor(None, self.embedding_queue.init)

    async def get_embeddings(self, docs, on_progress=None):
        """Embed every doc that has a "text" field.

        Returns ``(embeddings, token_usage)``; each entry carries the doc
        id, its text, the embedding (None when no text), and the token
        count.  ``on_progress``, when given, is awaited with (done, total)
        before and after the batch.
        """
        if len(docs) == 0:
            return ([], 0)

        if on_progress is not None:
            await on_progress(0, len(docs))

        embeddings = []
        token_usage = 0

        for doc in docs:
            if "text" in doc:
                tokens = await self.tiktoken.get_tokens(doc["text"])
                token_usage += tokens
                embeddings.append({
                    "id": doc["id"],
                    "text": doc["text"],
                    # BUG FIX: this used self.model.encode, but no
                    # self.model attribute is ever created — the queue
                    # built in init() is the actual embedding backend.
                    "embedding": await self.embedding_queue.get_embeddings(doc["text"]),
                    "tokens": tokens
                })
            else:
                embeddings.append({
                    "id": doc["id"],
                    # BUG FIX: doc["text"] raised KeyError here, since this
                    # branch is only reached when "text" is absent.
                    "text": None,
                    "embedding": None,
                    "tokens": 0
                })

        if on_progress is not None:
            await on_progress(1, len(docs))

        return (embeddings, token_usage)
|
|
@ -0,0 +1,152 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from copy import deepcopy
|
||||||
|
import time
|
||||||
|
from lib.config import Config
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
from typing import Callable, Optional, TypedDict
|
||||||
|
import torch
|
||||||
|
from text2vec import SentenceModel
|
||||||
|
from utils.local import loop
|
||||||
|
from service.openai_api import OpenAIApi
|
||||||
|
from service.tiktoken import TikTokenService
|
||||||
|
|
||||||
|
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
|
||||||
|
|
||||||
|
class Text2VecEmbeddingQueueTaskInfo(TypedDict):
    """One pending/finished request in the text2vec embedding queue."""

    task_id: int  # random identifier, key into the queue's task map
    text: str  # raw input text to encode
    embedding: torch.Tensor  # None until the worker thread fills it in
|
||||||
|
|
||||||
|
class Text2VecEmbeddingQueue:
    """Serializes text2vec encode requests onto one worker thread.

    The worker starts lazily on the first request and shuts itself down
    after the queue has been idle for BERT_EMBEDDING_QUEUE_TIMEOUT seconds.
    """

    def __init__(self, model: str) -> None:
        self.model_name = model

        # Model weights load here, synchronously, in the caller's thread.
        self.embedding_model = SentenceModel(self.model_name)
        self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
        self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Queue ``text`` and poll until the worker fills its embedding."""
        task_id = random.randint(0, 1000000000)
        with self.lock:
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            entry = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = entry
            self.task_list.append(entry)

        self.start_queue()
        while True:
            finished = self.pop_task(task_id)
            if finished is not None:
                return finished["embedding"]

            await asyncio.sleep(0.01)

    def pop_task(self, task_id):
        """Remove and return the finished task, or None while pending."""
        with self.lock:
            entry = self.task_map.get(task_id)
            if entry is not None and entry["embedding"] is not None:
                del self.task_map[task_id]
                return entry
        return None

    def run(self):
        """Worker loop: encode queued texts until idle past the timeout."""
        alive = True
        last_task_time = None
        while alive and self.running:
            now = time.time()
            pending = None
            with self.lock:
                if self.task_list:
                    pending = self.task_list.pop(0)

            if pending is not None:
                vectors = self.embedding_model.encode([pending["text"]])
                with self.lock:
                    pending["embedding"] = vectors[0]

                last_task_time = time.time()
            elif last_task_time is not None and now > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
                # Idle too long: let the thread die; a later request
                # starts a fresh one via start_queue().
                self.thread = None
                self.running = False
                alive = False
            else:
                time.sleep(0.01)

    def start_queue(self):
        """Start the worker thread if it is not already running."""
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()
|
||||||
|
|
||||||
|
class TextEmbeddingService:
    """Singleton that embeds documents via text2vec or the OpenAI API."""

    # Lazily-created singleton instance.
    instance = None

    @staticmethod
    async def create() -> TextEmbeddingService:
        """Return the shared instance, initializing it on first use."""
        if TextEmbeddingService.instance is None:
            TextEmbeddingService.instance = TextEmbeddingService()
            await TextEmbeddingService.instance.init()
        return TextEmbeddingService.instance

    async def init(self):
        """Pick the backend from config and build it."""
        self.tiktoken = await TikTokenService.create()

        self.embedding_type = Config.get("embedding.type", "text2vec")

        if self.embedding_type == "text2vec":
            embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
            # The model loads inside Text2VecEmbeddingQueue.__init__.
            self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model)
        elif self.embedding_type == "openai":
            self.openai_api: OpenAIApi = await OpenAIApi.create()
        # BUG FIX: this method previously always ran
        #   await loop.run_in_executor(None, self.embedding_queue.init)
        # but Text2VecEmbeddingQueue has no init() method (its model loads
        # in __init__), and on the "openai" path self.embedding_queue does
        # not exist at all — either way it raised AttributeError.

    async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Fill doc["embedding"] in place for each doc; returns doc_list."""
        for index, doc in enumerate(doc_list):
            text = doc["text"]
            embedding = await self.embedding_queue.get_embeddings(text)
            doc["embedding"] = embedding

            if on_index_progress is not None:
                await on_index_progress(index, len(doc_list))
        # BUG FIX: returning the list (instead of the implicit None) lets
        # get_embeddings() hand a result back to its caller on this path.
        return doc_list

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Normalize each doc's text, then embed with the chosen backend.

        Works on a deep copy, so the caller's doc_list is not mutated.
        """
        res_doc_list = deepcopy(doc_list)

        # Characters treated as sentence-ending punctuation (ASCII + CJK).
        # This is a plain string searched with str.find, not a compiled
        # regular expression.
        punctuation = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
        for doc in res_doc_list:
            text: str = doc["text"]
            text = text.replace("\r\n", "\n").replace("\r", "\n")
            if "\n" in text:
                new_lines = []
                for line in text.split("\n"):
                    line = line.strip()
                    # Add a dot at the end of the line if it doesn't end
                    # with a punctuation mark, so sentences don't run
                    # together when the lines are joined below.
                    if len(line) > 0 and punctuation.find(line[-1]) == -1:
                        line += "."
                    new_lines.append(line)
                text = " ".join(new_lines)
            doc["text"] = text

        if self.embedding_type == "text2vec":
            return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
        elif self.embedding_type == "openai":
            return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
|
@ -0,0 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
|
from lib.noawait import NoAwaitPool
|
||||||
|
|
||||||
|
# Global debug flag for the application.
debug = False

# Single shared event loop plus the fire-and-forget pool bound to it.
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)
|
Loading…
Reference in New Issue