完成ChatComplete重构

master
落雨楓 9 months ago
parent 2a1e5c1589
commit 1355dbbf35

@ -6,7 +6,7 @@ from api.controller.task.ChatCompleteTask import ChatCompleteTask
from api.model.base import clone_model from api.model.base import clone_model
from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.toolkit_ui.conversation import ConversationHelper from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait from utils.local import noawait
from aiohttp import web from aiohttp import web
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse

@ -2,7 +2,7 @@ from __future__ import annotations
import sys import sys
import time import time
import traceback import traceback
from local import noawait from utils.local import noawait
from typing import Optional, Callable, Union from typing import Optional, Callable, Union
from service.chat_complete import ( from service.chat_complete import (
ChatCompleteService, ChatCompleteService,

@ -28,6 +28,7 @@ class BotPersonaModel(BaseModel):
), ),
index=True, index=True,
) )
api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)

@ -6,7 +6,7 @@ import asyncpg
from api.model.base import BaseModel from api.model.base import BaseModel
import numpy as np import numpy as np
import sqlalchemy import sqlalchemy
from config import Config from lib.config import Config
from sqlalchemy import Index, select, update, delete, Select from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -17,7 +17,7 @@ from service.database import DatabaseService
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {} page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) embedding_vector_size = Config.get("embedding.vector_size", 1536, int)
class AbstractPageIndexModel(BaseModel): class AbstractPageIndexModel(BaseModel):
__abstract__ = True __abstract__ = True

@ -8,11 +8,11 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from config import Config from lib.config import Config
from api.model.base import BaseHelper, BaseModel from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService from service.database import DatabaseService
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) embedding_vector_size = Config.get("embedding.vector_size", 1536, int)
class TitleIndexModel(BaseModel): class TitleIndexModel(BaseModel):
__tablename__ = "embedding_search_title_index" __tablename__ = "embedding_search_title_index"

@ -1,56 +0,0 @@
from __future__ import annotations
from typing import TypeVar
import toml
class Config:
values: dict = {}
@staticmethod
def load_config(file):
with open(file, "r", encoding="utf-8") as f:
Config.values = toml.load(f)
@staticmethod
def get(key: str, default=None, type=None, empty_is_none=False):
key_path = key.split(".")
value = Config.values
for k in key_path:
if k in value:
value = value[k]
else:
return default
if empty_is_none and value == "":
return None
if type == bool:
if isinstance(value, bool):
return value
elif isinstance(value, int) or isinstance(value, float):
return value != 0
else:
return str(value).lower() in ("yes", "true", "1")
elif type == int:
return int(value)
elif type == float:
return float(value)
elif type == str:
return str(value)
elif type == list:
if not isinstance(value, list):
return []
elif type == dict:
if not isinstance(value, dict):
return {}
else:
return value
@staticmethod
def set(key: str, value):
key_path = key.split(".")
obj = Config.values
for k in key_path[:-1]:
if k not in obj:
obj[k] = {}
obj = obj[k]
obj[key_path[-1]] = value

@ -1,6 +0,0 @@
import asyncio
from noawait import NoAwaitPool
debug = False
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)

@ -1,12 +1,12 @@
import sys import sys
import traceback import traceback
from config import Config from lib.config import Config
Config.load_config("config.toml") Config.load_config("config.toml")
from local import loop, noawait from utils.local import loop, noawait
from aiohttp import web from aiohttp import web
import local import utils.local as local
import api.route import api.route
import utils.web import utils.web
from service.database import DatabaseService from service.database import DatabaseService

@ -1,7 +1,7 @@
import asyncio import asyncio
import sys import sys
import base as _ import base as _
import local import utils.local as local
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService from service.embedding_search import EmbeddingSearchService

@ -1,165 +0,0 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
import atexit
from functools import wraps
import random
import sys
import traceback
from typing import Callable, Coroutine, Optional, TypedDict
class TimerInfo(TypedDict):
id: int
callback: Callable
interval: float
next_time: float
class NoAwaitPool:
def __init__(self, loop: AbstractEventLoop):
self.task_list: list[Task] = []
self.timer_map: dict[int, TimerInfo] = {}
self.loop = loop
self.running = True
self.should_refresh_task = False
self.next_timer_time: Optional[float] = None
self.on_error: list[Callable] = []
self.gc_task = loop.create_task(self._run_gc())
self.timer_task = loop.create_task(self._run_timer())
atexit.register(self.end_task)
async def end(self):
if self.running:
print("Stopping NoAwait Tasks...")
self.running = False
for task in self.task_list:
await self._finish_task(task)
await self.gc_task
await self.timer_task
def end_task(self):
if self.running and not self.loop.is_closed():
self.loop.run_until_complete(self.end())
async def _wrap_task(self, task: Task):
try:
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
finally:
self.should_refresh_task = True
def add_task(self, coroutine: Coroutine):
task = self.loop.create_task(coroutine)
self.task_list.append(task)
def add_timer(self, callback: Callable, interval: float) -> int:
id = random.randint(0, 1000000000)
while id in self.timer_map:
id = random.randint(0, 1000000000)
now = self.loop.time()
next_time = now + interval
self.timer_map[id] = {
"id": id,
"callback": callback,
"interval": interval,
"next_time": next_time
}
if self.next_timer_time is None or next_time < self.next_timer_time:
self.next_timer_time = next_time
return id
def remove_timer(self, id: int):
if id in self.timer_map:
del self.timer_map[id]
def wrap(self, f):
@wraps(f)
def decorated_function(*args, **kwargs):
coroutine = f(*args, **kwargs)
self.add_task(coroutine)
return decorated_function
async def _finish_task(self, task: Task):
try:
if not task.done():
task.cancel()
await task
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
await handler_ret
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
async def _run_gc(self):
while self.running:
if self.should_refresh_task:
should_remove = []
for task in self.task_list:
if task.done():
await self._finish_task(task)
should_remove.append(task)
for task in should_remove:
self.task_list.remove(task)
await asyncio.sleep(0.1)
async def _run_timer(self):
while self.running:
now = self.loop.time()
if self.next_timer_time is not None and now >= self.next_timer_time:
self.next_timer_time = None
for timer in self.timer_map.values():
if now >= timer["next_time"]:
timer["next_time"] = now + timer["interval"]
try:
result = timer["callback"]()
self.add_task(result)
except Exception as e:
handled = False
for handler in self.on_error:
try:
handler_ret = handler(e)
self.add_task(handler_ret)
handled = True
except Exception as handler_err:
print("Exception on error handler: " + str(handler_err), file=sys.stderr)
traceback.print_exc()
if not handled:
print(e, file=sys.stderr)
traceback.print_exc()
if self.next_timer_time is None or timer["next_time"] < self.next_timer_time:
self.next_timer_time = timer["next_time"]
await asyncio.sleep(0.1)

@ -1,5 +1,3 @@
transformers text2vec>=1.2.9
--index-url https://download.pytorch.org/whl/cpu --index-url https://download.pytorch.org/whl/cpu
torch torch
torchvision
torchaudio

@ -1,143 +0,0 @@
from __future__ import annotations
import time
from config import Config
import asyncio
import random
import threading
from typing import Optional, TypedDict
import torch
from transformers import pipeline
from local import loop
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
class BERTEmbeddingQueueTaskInfo(TypedDict):
task_id: int
text: str
embedding: torch.Tensor
class BERTEmbeddingQueue:
def init(self):
self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock()
self.thread: Optional[threading.Thread] = None
self.running = False
async def get_embeddings(self, text: str):
text = "[CLS]" + text + "[SEP]"
task_id = random.randint(0, 1000000000)
with self.lock:
while task_id in self.task_map:
task_id = random.randint(0, 1000000000)
task_info = {
"task_id": task_id,
"text": text,
"embedding": None
}
self.task_map[task_id] = task_info
self.task_list.append(task_info)
self.start_queue()
while True:
task_info = self.pop_task(task_id)
if task_info is not None:
return task_info["embedding"]
await asyncio.sleep(0.01)
def pop_task(self, task_id):
with self.lock:
if task_id in self.task_map:
task_info = self.task_map[task_id]
if task_info["embedding"] is not None:
del self.task_map[task_id]
return task_info
return None
def run(self):
running = True
last_task_time = None
while running and self.running:
current_time = time.time()
task = None
with self.lock:
if len(self.task_list) > 0:
task = self.task_list.pop(0)
if task is not None:
embeddings = self.embedding_model(task["text"])
with self.lock:
task["embedding"] = embeddings[0][1]
last_task_time = time.time()
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
self.thread = None
self.running = False
running = False
else:
time.sleep(0.01)
def start_queue(self):
if not self.running:
self.running = True
self.thread = threading.Thread(target=self.run)
self.thread.start()
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
class BERTEmbeddingService:
instance = None
@staticmethod
async def create() -> BERTEmbeddingService:
if BERTEmbeddingService.instance is None:
BERTEmbeddingService.instance = BERTEmbeddingService()
await BERTEmbeddingService.instance.init()
return BERTEmbeddingService.instance
async def init(self):
self.tiktoken = await TikTokenService.create()
self.embedding_queue = BERTEmbeddingQueue()
await loop.run_in_executor(None, self.embedding_queue.init)
async def get_embeddings(self, docs, on_progress=None):
if len(docs) == 0:
return ([], 0)
if on_progress is not None:
await on_progress(0, len(docs))
embeddings = []
token_usage = 0
for doc in docs:
if "text" in doc:
tokens = await self.tiktoken.get_tokens(doc["text"])
token_usage += tokens
embeddings.append({
"id": doc["id"],
"text": doc["text"],
"embedding": self.model.encode(doc["text"]),
"tokens": tokens
})
else:
embeddings.append({
"id": doc["id"],
"text": doc["text"],
"embedding": None,
"tokens": 0
})
if on_progress is not None:
await on_progress(1, len(docs))
return (embeddings, token_usage)

@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import (
import sys import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
from config import Config from lib.config import Config
import utils.config, utils.web import utils.config, utils.web
from aiohttp import web from aiohttp import web
@ -20,7 +20,7 @@ from sqlalchemy.orm.attributes import flag_modified
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
class ChatCompleteQuestionTooLongException(Exception): class ChatCompleteQuestionTooLongException(Exception):
@ -34,6 +34,7 @@ class ChatCompleteServicePrepareResponse(TypedDict):
question_tokens: int question_tokens: int
conversation_id: int conversation_id: int
chunk_id: int chunk_id: int
api_id: str
class ChatCompleteServiceResponse(TypedDict): class ChatCompleteServiceResponse(TypedDict):
message: str message: str
@ -60,12 +61,12 @@ class ChatCompleteService:
self.conversation_info: Optional[ConversationModel] = None self.conversation_info: Optional[ConversationModel] = None
self.conversation_chunk: Optional[ConversationChunkModel] = None self.conversation_chunk: Optional[ConversationChunkModel] = None
self.openai_api: OpenAIApi = None
self.tiktoken: TikTokenService = None self.tiktoken: TikTokenService = None
self.extract_doc: list = None self.extract_doc: list = None
self.mwapi = MediaWikiApi.create() self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.user_id = 0 self.user_id = 0
self.question = "" self.question = ""
@ -156,6 +157,12 @@ class ChatCompleteService:
else: else:
self.chat_system_prompt = bot_persona.system_prompt self.chat_system_prompt = bot_persona.system_prompt
default_api = Config.get("chatcomplete.default_api", None, str)
try:
self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api)
except OpenAIApiTypeInvalidException:
self.openai_api = OpenAIApi.create(default_api)
self.conversation_chunk = None self.conversation_chunk = None
if self.conversation_info is not None: if self.conversation_info is not None:
chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id) chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import local import utils.local as local
from urllib.parse import quote_plus from urllib.parse import quote_plus
from aiohttp import web from aiohttp import web
import asyncpg import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from config import Config from lib.config import Config
def get_dsn(): def get_dsn():
db_conf = Config.get("database") db_conf = Config.get("database")

@ -11,6 +11,7 @@ from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi from service.openai_api import OpenAIApi
from service.text_embedding import TextEmbeddingService
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences from utils.wiki import getWikiSentences
@ -38,6 +39,7 @@ class EmbeddingSearchService:
self.title_collection_helper = TitleCollectionHelper(dbs) self.title_collection_helper = TitleCollectionHelper(dbs)
self.page_index: PageIndexHelper = None self.page_index: PageIndexHelper = None
self.text_embedding: TextEmbeddingService = None
self.tiktoken: TikTokenService = None self.tiktoken: TikTokenService = None
self.mwapi = MediaWikiApi.create() self.mwapi = MediaWikiApi.create()
@ -53,6 +55,7 @@ class EmbeddingSearchService:
self.unindexed_docs: list = None self.unindexed_docs: list = None
async def __aenter__(self): async def __aenter__(self):
self.text_embedding = await TextEmbeddingService.create()
self.tiktoken = await TikTokenService.create() self.tiktoken = await TikTokenService.create()
await self.title_index_helper.__aenter__() await self.title_index_helper.__aenter__()
@ -225,7 +228,7 @@ class EmbeddingSearchService:
await on_progress(indexed_docs, len(self.unindexed_docs)) await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk): async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings( (doc_chunk, token_usage) = await self.text_embedding.get_embeddings(
doc_chunk, on_embedding_progress doc_chunk, on_embedding_progress
) )
await self.page_index.index_doc(doc_chunk, self.page_id) await self.page_index.index_doc(doc_chunk, self.page_id)
@ -266,7 +269,7 @@ class EmbeddingSearchService:
# Update title embedding # Update title embedding
if await self.title_index.awaitable_attrs.embedding is None: if await self.title_index.awaitable_attrs.embedding is None:
doc_chunk = [{"text": self.title}] doc_chunk = [{"text": self.title}]
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) (doc_chunk, token_usage) = await self.text_embedding.get_embeddings(doc_chunk)
total_token_usage += token_usage total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"] embedding = doc_chunk[0]["embedding"]
@ -290,7 +293,7 @@ class EmbeddingSearchService:
raise Exception("Page index is not initialized") raise Exception("Page index is not initialized")
query_doc = [{"text": query}] query_doc = [{"text": query}]
query_doc, token_usage = await self.openai_api.get_embeddings(query_doc) query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
query_embedding = query_doc[0]["embedding"] query_embedding = query_doc[0]["embedding"]
if query_embedding is None: if query_embedding is None:

@ -4,7 +4,7 @@ import sys
import time import time
from typing import Optional, TypedDict from typing import Optional, TypedDict
import aiohttp import aiohttp
from config import Config from lib.config import Config
class MediaWikiApiException(Exception): class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None: def __init__(self, info: str, code: Optional[str] = None) -> None:

@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict from typing import Callable, Optional, TypedDict
import aiohttp import aiohttp
from config import Config from lib.config import Config
import numpy as np import numpy as np
from aiohttp_sse_client2 import client as sse_client from aiohttp_sse_client2 import client as sse_client
@ -13,6 +13,13 @@ from service.tiktoken import TikTokenService
AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview" AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
AZURE_EMBEDDING_API_VERSION = "2023-05-15" AZURE_EMBEDDING_API_VERSION = "2023-05-15"
class OpenAIApiTypeInvalidException(Exception):
def __init__(self, api_id: str):
self.api_id = api_id
def __str__(self):
return f"Invalid api_id: {self.api_id}"
class ChatCompleteMessageLog(TypedDict): class ChatCompleteMessageLog(TypedDict):
role: str role: str
content: str content: str
@ -26,19 +33,21 @@ class ChatCompleteResponse(TypedDict):
class OpenAIApi: class OpenAIApi:
@staticmethod @staticmethod
def create(): def create(api_id: str) -> OpenAIApi:
return OpenAIApi() return OpenAIApi(api_id)
def __init__(self): def __init__(self, api_id: str):
self.api_type = Config.get("chatcomplete.api_type", "openai", str) self.api_id = api_id
self.request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
if self.api_type == "azure": self.api_type = Config.get(f"chatcomplete.{api_id}.api_type", None, str)
self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.azure.key", type=str) if self.api_type is None:
else: raise OpenAIApiTypeInvalidException(api_id)
self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.openai.key", type=str) self.request_proxy = Config.get(f"chatcomplete.{api_id}.request_proxy", type=str, empty_is_none=True)
self.api_url = Config.get(f"chatcomplete.{api_id}.api_endpoint", type=str)
self.api_key = Config.get(f"chatcomplete.{api_id}.key", type=str)
def build_header(self): def build_header(self):
if self.api_type == "azure": if self.api_type == "azure":
@ -56,7 +65,7 @@ class OpenAIApi:
def get_url(self, method: str): def get_url(self, method: str):
if self.api_type == "azure": if self.api_type == "azure":
deployments = Config.get("chatcomplete.azure.deployments") deployments = Config.get(f"chatcomplete.{self.api_id}.deployments")
if method == "chat/completions": if method == "chat/completions":
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
elif method == "embeddings": elif method == "embeddings":
@ -65,25 +74,10 @@ class OpenAIApi:
return self.api_url + "/v1/" + method return self.api_url + "/v1/" + method
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
text_list = []
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
for doc in doc_list:
text: str = doc["text"]
text = text.replace("\r\n", "\n").replace("\r", "\n")
if "\n" in text:
lines = text.split("\n")
new_lines = []
for line in lines:
line = line.strip()
# Add a dot at the end of the line if it doesn't end with a punctuation mark
if len(line) > 0 and regex.find(line[-1]) == -1:
line += "."
new_lines.append(line)
text = " ".join(new_lines)
text_list.append(text)
token_usage = 0 token_usage = 0
text_list = [doc["text"] for doc in doc_list]
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
url = self.get_url("embeddings") url = self.get_url("embeddings")
params = {} params = {}

@ -0,0 +1,152 @@
from __future__ import annotations
from copy import deepcopy
import time
from lib.config import Config
import asyncio
import random
import threading
from typing import Callable, Optional, TypedDict
import torch
from text2vec import SentenceModel
from utils.local import loop
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
class Text2VecEmbeddingQueueTaskInfo(TypedDict):
task_id: int
text: str
embedding: torch.Tensor
class Text2VecEmbeddingQueue:
def __init__(self, model: str) -> None:
self.model_name = model
self.embedding_model = SentenceModel(self.model_name)
self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
self.lock = threading.Lock()
self.thread: Optional[threading.Thread] = None
self.running = False
async def get_embeddings(self, text: str):
task_id = random.randint(0, 1000000000)
with self.lock:
while task_id in self.task_map:
task_id = random.randint(0, 1000000000)
task_info = {
"task_id": task_id,
"text": text,
"embedding": None
}
self.task_map[task_id] = task_info
self.task_list.append(task_info)
self.start_queue()
while True:
task_info = self.pop_task(task_id)
if task_info is not None:
return task_info["embedding"]
await asyncio.sleep(0.01)
def pop_task(self, task_id):
with self.lock:
if task_id in self.task_map:
task_info = self.task_map[task_id]
if task_info["embedding"] is not None:
del self.task_map[task_id]
return task_info
return None
def run(self):
running = True
last_task_time = None
while running and self.running:
current_time = time.time()
task = None
with self.lock:
if len(self.task_list) > 0:
task = self.task_list.pop(0)
if task is not None:
embeddings = self.embedding_model.encode([task["text"]])
with self.lock:
task["embedding"] = embeddings[0]
last_task_time = time.time()
elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT:
self.thread = None
self.running = False
running = False
else:
time.sleep(0.01)
def start_queue(self):
if not self.running:
self.running = True
self.thread = threading.Thread(target=self.run)
self.thread.start()
class TextEmbeddingService:
instance = None
@staticmethod
async def create() -> TextEmbeddingService:
if TextEmbeddingService.instance is None:
TextEmbeddingService.instance = TextEmbeddingService()
await TextEmbeddingService.instance.init()
return TextEmbeddingService.instance
async def init(self):
self.tiktoken = await TikTokenService.create()
self.embedding_type = Config.get("embedding.type", "text2vec")
if self.embedding_type == "text2vec":
embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model)
elif self.embedding_type == "openai":
self.openai_api: OpenAIApi = await OpenAIApi.create()
await loop.run_in_executor(None, self.embedding_queue.init)
async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
for index, doc in enumerate(doc_list):
text = doc["text"]
embedding = await self.embedding_queue.get_embeddings(text)
doc["embedding"] = embedding
if on_index_progress is not None:
await on_index_progress(index, len(doc_list))
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
res_doc_list = deepcopy(doc_list)
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
for doc in res_doc_list:
text: str = doc["text"]
text = text.replace("\r\n", "\n").replace("\r", "\n")
if "\n" in text:
lines = text.split("\n")
new_lines = []
for line in lines:
line = line.strip()
# Add a dot at the end of the line if it doesn't end with a punctuation mark
if len(line) > 0 and regex.find(line[-1]) == -1:
line += "."
new_lines.append(line)
text = " ".join(new_lines)
doc["text"] = text
if self.embedding_type == "text2vec":
return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
elif self.embedding_type == "openai":
return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)

@ -2,9 +2,9 @@ import sys
import pathlib import pathlib
root_path = pathlib.Path(__file__).parent.parent root_path = pathlib.Path(__file__).parent.parent
sys.path.append(root_path) sys.path.append(".")
from config import Config from lib.config import Config
Config.load_config(root_path + "/config.toml") Config.load_config(str(root_path) + "/config.toml")
Config.set("server.debug", True) Config.set("server.debug", True)

@ -1,6 +1,6 @@
import traceback import traceback
import base import base
import local import utils.local as local
from service.chat_complete import ChatCompleteService from service.chat_complete import ChatCompleteService
from service.database import DatabaseService from service.database import DatabaseService

@ -3,7 +3,7 @@ import base
from sqlalchemy import select from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel from api.model.embedding_search.title_index import TitleIndexModel
import local import utils.local as local
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService from service.embedding_search import EmbeddingSearchService

@ -1,5 +1,5 @@
import base import base
import local import utils.local as local
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService from service.embedding_search import EmbeddingSearchService

@ -1,13 +1,21 @@
import asyncio import asyncio
import time import time
import base import base as _
from local import loop, noawait from utils.local import loop, noawait
from service.bert_embedding import bert_embedding_queue from service.text_embedding import Text2VecEmbeddingQueue
async def main(): async def main():
embedding_list = [] embedding_list = []
start_time = time.time()
queue = [] queue = []
text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese")
start_time = time.time()
async def on_progress(current, total):
print(f"{current}/{total}")
total_lines = 0
with open("test/test.md", "r", encoding="utf-8") as fp: with open("test/test.md", "r", encoding="utf-8") as fp:
text = fp.read() text = fp.read()
lines = text.split("\n") lines = text.split("\n")
@ -16,10 +24,13 @@ async def main():
if line == "": if line == "":
continue continue
queue.append(bert_embedding_queue.get_embeddings(line)) queue.append(text2vec_queue.get_embeddings(line))
total_lines += 0
embedding_list = await asyncio.gather(*queue) embedding_list = await asyncio.gather(*queue)
end_time = time.time() end_time = time.time()
print("total lines: %d" % total_lines)
print("time cost: %.4f" % (end_time - start_time)) print("time cost: %.4f" % (end_time - start_time))
print("speed: %.4f it/s" % (total_lines / (end_time - start_time)))
print("dimensions: %d" % len(embedding_list[0])) print("dimensions: %d" % len(embedding_list[0]))
await noawait.end() await noawait.end()

@ -1,6 +1,6 @@
import asyncio import asyncio
import base import base
from local import loop, noawait from utils.local import loop, noawait
async def test_timer1(): async def test_timer1():
print("timer1") print("timer1")

@ -3,7 +3,7 @@ import base
from sqlalchemy import select from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel from api.model.embedding_search.title_index import TitleIndexModel
import local import utils.local as local
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService from service.embedding_search import EmbeddingSearchService

@ -1,5 +1,5 @@
import time import time
from config import Config from lib.config import Config
def get_prompt(name: str, type: str, params: dict = {}): def get_prompt(name: str, type: str, params: dict = {}):
sys_params = { sys_params = {

@ -0,0 +1,6 @@
import asyncio
from lib.noawait import NoAwaitPool
debug = False
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)

@ -5,7 +5,7 @@ from typing import Any, Optional, Dict
from aiohttp import web from aiohttp import web
import jwt import jwt
import uuid import uuid
from config import Config from lib.config import Config
ParamRule = Dict[str, Any] ParamRule = Dict[str, Any]

Loading…
Cancel
Save