完成ChatComplete重构
parent
2a1e5c1589
commit
1355dbbf35
@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypeVar
|
||||
import toml
|
||||
|
||||
class Config:
|
||||
values: dict = {}
|
||||
|
||||
@staticmethod
|
||||
def load_config(file):
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
Config.values = toml.load(f)
|
||||
|
||||
@staticmethod
|
||||
def get(key: str, default=None, type=None, empty_is_none=False):
|
||||
key_path = key.split(".")
|
||||
value = Config.values
|
||||
for k in key_path:
|
||||
if k in value:
|
||||
value = value[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
if empty_is_none and value == "":
|
||||
return None
|
||||
|
||||
if type == bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
elif isinstance(value, int) or isinstance(value, float):
|
||||
return value != 0
|
||||
else:
|
||||
return str(value).lower() in ("yes", "true", "1")
|
||||
elif type == int:
|
||||
return int(value)
|
||||
elif type == float:
|
||||
return float(value)
|
||||
elif type == str:
|
||||
return str(value)
|
||||
elif type == list:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
elif type == dict:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
else:
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def set(key: str, value):
|
||||
key_path = key.split(".")
|
||||
obj = Config.values
|
||||
for k in key_path[:-1]:
|
||||
if k not in obj:
|
||||
obj[k] = {}
|
||||
obj = obj[k]
|
||||
obj[key_path[-1]] = value
|
@ -1,6 +0,0 @@
|
||||
import asyncio
|
||||
from noawait import NoAwaitPool
|
||||
|
||||
debug = False
|
||||
loop = asyncio.new_event_loop()
|
||||
noawait = NoAwaitPool(loop)
|
@ -1,165 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from asyncio import AbstractEventLoop, Task
|
||||
import asyncio
|
||||
import atexit
|
||||
from functools import wraps
|
||||
import random
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, Coroutine, Optional, TypedDict
|
||||
|
||||
class TimerInfo(TypedDict):
|
||||
id: int
|
||||
callback: Callable
|
||||
interval: float
|
||||
next_time: float
|
||||
|
||||
class NoAwaitPool:
|
||||
def __init__(self, loop: AbstractEventLoop):
|
||||
self.task_list: list[Task] = []
|
||||
self.timer_map: dict[int, TimerInfo] = {}
|
||||
self.loop = loop
|
||||
self.running = True
|
||||
|
||||
self.should_refresh_task = False
|
||||
self.next_timer_time: Optional[float] = None
|
||||
|
||||
self.on_error: list[Callable] = []
|
||||
|
||||
self.gc_task = loop.create_task(self._run_gc())
|
||||
self.timer_task = loop.create_task(self._run_timer())
|
||||
|
||||
atexit.register(self.end_task)
|
||||
|
||||
async def end(self):
|
||||
if self.running:
|
||||
print("Stopping NoAwait Tasks...")
|
||||
self.running = False
|
||||
for task in self.task_list:
|
||||
await self._finish_task(task)
|
||||
|
||||
await self.gc_task
|
||||
await self.timer_task
|
||||
|
||||
def end_task(self):
|
||||
if self.running and not self.loop.is_closed():
|
||||
self.loop.run_until_complete(self.end())
|
||||
|
||||
async def _wrap_task(self, task: Task):
|
||||
try:
|
||||
await task
|
||||
except Exception as e:
|
||||
handled = False
|
||||
for handler in self.on_error:
|
||||
try:
|
||||
handler_ret = handler(e)
|
||||
await handler_ret
|
||||
handled = True
|
||||
except Exception as handler_err:
|
||||
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if not handled:
|
||||
print(e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.should_refresh_task = True
|
||||
|
||||
def add_task(self, coroutine: Coroutine):
|
||||
task = self.loop.create_task(coroutine)
|
||||
self.task_list.append(task)
|
||||
|
||||
def add_timer(self, callback: Callable, interval: float) -> int:
|
||||
id = random.randint(0, 1000000000)
|
||||
while id in self.timer_map:
|
||||
id = random.randint(0, 1000000000)
|
||||
|
||||
now = self.loop.time()
|
||||
next_time = now + interval
|
||||
self.timer_map[id] = {
|
||||
"id": id,
|
||||
"callback": callback,
|
||||
"interval": interval,
|
||||
"next_time": next_time
|
||||
}
|
||||
|
||||
if self.next_timer_time is None or next_time < self.next_timer_time:
|
||||
self.next_timer_time = next_time
|
||||
|
||||
return id
|
||||
|
||||
def remove_timer(self, id: int):
|
||||
if id in self.timer_map:
|
||||
del self.timer_map[id]
|
||||
|
||||
def wrap(self, f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
coroutine = f(*args, **kwargs)
|
||||
self.add_task(coroutine)
|
||||
|
||||
return decorated_function
|
||||
|
||||
async def _finish_task(self, task: Task):
|
||||
try:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await task
|
||||
except Exception as e:
|
||||
handled = False
|
||||
for handler in self.on_error:
|
||||
try:
|
||||
handler_ret = handler(e)
|
||||
await handler_ret
|
||||
handled = True
|
||||
except Exception as handler_err:
|
||||
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if not handled:
|
||||
print(e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
async def _run_gc(self):
|
||||
while self.running:
|
||||
if self.should_refresh_task:
|
||||
should_remove = []
|
||||
for task in self.task_list:
|
||||
if task.done():
|
||||
await self._finish_task(task)
|
||||
should_remove.append(task)
|
||||
for task in should_remove:
|
||||
self.task_list.remove(task)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _run_timer(self):
|
||||
while self.running:
|
||||
now = self.loop.time()
|
||||
if self.next_timer_time is not None and now >= self.next_timer_time:
|
||||
self.next_timer_time = None
|
||||
for timer in self.timer_map.values():
|
||||
if now >= timer["next_time"]:
|
||||
timer["next_time"] = now + timer["interval"]
|
||||
try:
|
||||
result = timer["callback"]()
|
||||
self.add_task(result)
|
||||
except Exception as e:
|
||||
handled = False
|
||||
for handler in self.on_error:
|
||||
try:
|
||||
handler_ret = handler(e)
|
||||
self.add_task(handler_ret)
|
||||
handled = True
|
||||
except Exception as handler_err:
|
||||
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
if not handled:
|
||||
print(e, file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
|
||||
self.next_timer_time = timer["next_time"]
|
||||
|
||||
await asyncio.sleep(0.1)
|
@ -1,5 +1,3 @@
|
||||
transformers
|
||||
text2vec>=1.2.9
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
@ -1,143 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
from config import Config
|
||||
import asyncio
|
||||
import random
|
||||
import threading
|
||||
from typing import Optional, TypedDict
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
from local import loop
|
||||
|
||||
from service.tiktoken import TikTokenService
|
||||
|
||||
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
|
||||
|
||||
class BERTEmbeddingQueueTaskInfo(TypedDict):
|
||||
task_id: int
|
||||
text: str
|
||||
embedding: torch.Tensor
|
||||
|
||||
class BERTEmbeddingQueue:
|
||||
def init(self):
|
||||
self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
|
||||
self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
|
||||
self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
self.running = False
|
||||
|
||||
async def get_embeddings(self, text: str):
|
||||
text = "[CLS]" + text + "[SEP]"
|
||||
task_id = random.randint(0, 1000000000)
|
||||
with self.lock:
|
||||
while task_id in self.task_map:
|
||||
task_id = random.randint(0, 1000000000)
|
||||
|
||||
task_info = {
|
||||
"task_id": task_id,
|
||||
"text": text,
|
||||
"embedding": None
|
||||
}
|
||||
self.task_map[task_id] = task_info
|
||||
self.task_list.append(task_info)
|
||||
|
||||
self.start_queue()
|
||||
while True:
|
||||
task_info = self.pop_task(task_id)
|
||||
if task_info is not None:
|
||||
return task_info["embedding"]
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
def pop_task(self, task_id):
|
||||
with self.lock:
|
||||
if task_id in self.task_map:
|
||||
task_info = self.task_map[task_id]
|
||||
if task_info["embedding"] is not None:
|
||||
del self.task_map[task_id]
|
||||
return task_info
|
||||
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
running = True
|
||||
last_task_time = None
|
||||
while running and self.running:
|
||||
current_time = time.time()
|
||||
task = None
|
||||
with self.lock:
|
||||
if len(self.task_list) > 0:
|
||||
task = self.task_list.pop(0)
|
||||
|
||||
if task is not None:
|
||||
embeddings = self.embedding_model(task["text"])
|
||||
|
||||
with self.lock:
|
||||
task["embedding"] = embeddings[0][1]
|
||||
|
||||
last_task_time = time.time()
|
||||
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
|
||||
self.thread = None
|
||||
self.running = False
|
||||
running = False
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
def start_queue(self):
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self.run)
|
||||
self.thread.start()
|
||||
|
||||
bert_embedding_queue = BERTEmbeddingQueue()
|
||||
bert_embedding_queue.init()
|
||||
|
||||
class BERTEmbeddingService:
|
||||
instance = None
|
||||
|
||||
@staticmethod
|
||||
async def create() -> BERTEmbeddingService:
|
||||
if BERTEmbeddingService.instance is None:
|
||||
BERTEmbeddingService.instance = BERTEmbeddingService()
|
||||
await BERTEmbeddingService.instance.init()
|
||||
return BERTEmbeddingService.instance
|
||||
|
||||
async def init(self):
|
||||
self.tiktoken = await TikTokenService.create()
|
||||
self.embedding_queue = BERTEmbeddingQueue()
|
||||
await loop.run_in_executor(None, self.embedding_queue.init)
|
||||
|
||||
async def get_embeddings(self, docs, on_progress=None):
|
||||
if len(docs) == 0:
|
||||
return ([], 0)
|
||||
|
||||
if on_progress is not None:
|
||||
await on_progress(0, len(docs))
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
for doc in docs:
|
||||
if "text" in doc:
|
||||
tokens = await self.tiktoken.get_tokens(doc["text"])
|
||||
token_usage += tokens
|
||||
embeddings.append({
|
||||
"id": doc["id"],
|
||||
"text": doc["text"],
|
||||
"embedding": self.model.encode(doc["text"]),
|
||||
"tokens": tokens
|
||||
})
|
||||
else:
|
||||
embeddings.append({
|
||||
"id": doc["id"],
|
||||
"text": doc["text"],
|
||||
"embedding": None,
|
||||
"tokens": 0
|
||||
})
|
||||
|
||||
if on_progress is not None:
|
||||
await on_progress(1, len(docs))
|
||||
|
||||
return (embeddings, token_usage)
|
@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
from copy import deepcopy
|
||||
import time
|
||||
from lib.config import Config
|
||||
import asyncio
|
||||
import random
|
||||
import threading
|
||||
from typing import Callable, Optional, TypedDict
|
||||
import torch
|
||||
from text2vec import SentenceModel
|
||||
from utils.local import loop
|
||||
from service.openai_api import OpenAIApi
|
||||
from service.tiktoken import TikTokenService
|
||||
|
||||
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
|
||||
|
||||
class Text2VecEmbeddingQueueTaskInfo(TypedDict):
|
||||
task_id: int
|
||||
text: str
|
||||
embedding: torch.Tensor
|
||||
|
||||
class Text2VecEmbeddingQueue:
|
||||
def __init__(self, model: str) -> None:
|
||||
self.model_name = model
|
||||
|
||||
self.embedding_model = SentenceModel(self.model_name)
|
||||
self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
|
||||
self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
self.running = False
|
||||
|
||||
|
||||
async def get_embeddings(self, text: str):
|
||||
task_id = random.randint(0, 1000000000)
|
||||
with self.lock:
|
||||
while task_id in self.task_map:
|
||||
task_id = random.randint(0, 1000000000)
|
||||
|
||||
task_info = {
|
||||
"task_id": task_id,
|
||||
"text": text,
|
||||
"embedding": None
|
||||
}
|
||||
self.task_map[task_id] = task_info
|
||||
self.task_list.append(task_info)
|
||||
|
||||
self.start_queue()
|
||||
while True:
|
||||
task_info = self.pop_task(task_id)
|
||||
if task_info is not None:
|
||||
return task_info["embedding"]
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
def pop_task(self, task_id):
|
||||
with self.lock:
|
||||
if task_id in self.task_map:
|
||||
task_info = self.task_map[task_id]
|
||||
if task_info["embedding"] is not None:
|
||||
del self.task_map[task_id]
|
||||
return task_info
|
||||
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
running = True
|
||||
last_task_time = None
|
||||
while running and self.running:
|
||||
current_time = time.time()
|
||||
task = None
|
||||
with self.lock:
|
||||
if len(self.task_list) > 0:
|
||||
task = self.task_list.pop(0)
|
||||
|
||||
if task is not None:
|
||||
embeddings = self.embedding_model.encode([task["text"]])
|
||||
|
||||
with self.lock:
|
||||
task["embedding"] = embeddings[0]
|
||||
|
||||
last_task_time = time.time()
|
||||
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
|
||||
self.thread = None
|
||||
self.running = False
|
||||
running = False
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
def start_queue(self):
|
||||
if not self.running:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self.run)
|
||||
self.thread.start()
|
||||
|
||||
class TextEmbeddingService:
|
||||
instance = None
|
||||
|
||||
@staticmethod
|
||||
async def create() -> TextEmbeddingService:
|
||||
if TextEmbeddingService.instance is None:
|
||||
TextEmbeddingService.instance = TextEmbeddingService()
|
||||
await TextEmbeddingService.instance.init()
|
||||
return TextEmbeddingService.instance
|
||||
|
||||
async def init(self):
|
||||
self.tiktoken = await TikTokenService.create()
|
||||
|
||||
self.embedding_type = Config.get("embedding.type", "text2vec")
|
||||
|
||||
if self.embedding_type == "text2vec":
|
||||
embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
|
||||
self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model)
|
||||
elif self.embedding_type == "openai":
|
||||
self.openai_api: OpenAIApi = await OpenAIApi.create()
|
||||
|
||||
await loop.run_in_executor(None, self.embedding_queue.init)
|
||||
|
||||
async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
|
||||
for index, doc in enumerate(doc_list):
|
||||
text = doc["text"]
|
||||
embedding = await self.embedding_queue.get_embeddings(text)
|
||||
doc["embedding"] = embedding
|
||||
|
||||
if on_index_progress is not None:
|
||||
await on_index_progress(index, len(doc_list))
|
||||
|
||||
|
||||
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
|
||||
res_doc_list = deepcopy(doc_list)
|
||||
|
||||
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
|
||||
for doc in res_doc_list:
|
||||
text: str = doc["text"]
|
||||
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
if "\n" in text:
|
||||
lines = text.split("\n")
|
||||
new_lines = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
# Add a dot at the end of the line if it doesn't end with a punctuation mark
|
||||
if len(line) > 0 and regex.find(line[-1]) == -1:
|
||||
line += "."
|
||||
new_lines.append(line)
|
||||
text = " ".join(new_lines)
|
||||
doc["text"] = text
|
||||
|
||||
if self.embedding_type == "text2vec":
|
||||
return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
|
||||
elif self.embedding_type == "openai":
|
||||
return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
|
@ -0,0 +1,6 @@
|
||||
import asyncio
|
||||
from lib.noawait import NoAwaitPool
|
||||
|
||||
debug = False
|
||||
loop = asyncio.new_event_loop()
|
||||
noawait = NoAwaitPool(loop)
|
Loading…
Reference in New Issue