From 1355dbbf35b6bab4469abb6cb61468f5f6d9a5b0 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Mon, 6 May 2024 18:45:52 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90ChatComplete=E9=87=8D?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controller/ChatComplete.py | 2 +- api/controller/task/ChatCompleteTask.py | 2 +- api/model/chat_complete/bot_persona.py | 1 + api/model/embedding_search/page_index.py | 4 +- api/model/embedding_search/title_index.py | 4 +- config.py | 56 ------ local.py | 6 - main.py | 6 +- maintenance/update_title_index.py | 2 +- noawait.py | 165 ------------------ .../embedding.txt | 6 +- service/bert_embedding.py | 143 --------------- service/chat_complete.py | 13 +- service/database.py | 4 +- service/embedding_search.py | 9 +- service/mediawiki_api.py | 2 +- service/openai_api.py | 54 +++--- service/text_embedding.py | 152 ++++++++++++++++ test/base.py | 6 +- test/chatcomplete.py | 2 +- test/create_token.py | 2 +- test/embedding_search.py | 2 +- ...g_queue.py => text2vec_embedding_queue.py} | 21 ++- test/timer.py | 2 +- test/title_index.py | 2 +- utils/config.py | 2 +- utils/local.py | 6 + utils/web.py | 2 +- 28 files changed, 240 insertions(+), 438 deletions(-) delete mode 100644 config.py delete mode 100644 local.py delete mode 100644 noawait.py rename requirements-embedding.txt => requirements/embedding.txt (54%) delete mode 100644 service/bert_embedding.py create mode 100644 service/text_embedding.py rename test/{bert_embedding_queue.py => text2vec_embedding_queue.py} (51%) create mode 100644 utils/local.py diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index 4d84a6f..e079861 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -6,7 +6,7 @@ from api.controller.task.ChatCompleteTask import ChatCompleteTask from api.model.base import clone_model from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.toolkit_ui.conversation import ConversationHelper -from local import noawait +from utils.local import noawait from aiohttp import web from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse diff --git a/api/controller/task/ChatCompleteTask.py b/api/controller/task/ChatCompleteTask.py index 8357ba8..c38c0c6 100644 --- a/api/controller/task/ChatCompleteTask.py +++ b/api/controller/task/ChatCompleteTask.py @@ -2,7 +2,7 @@ from __future__ import annotations import sys import time import traceback -from local import noawait +from utils.local import noawait from typing import Optional, Callable, Union from service.chat_complete import ( ChatCompleteService, diff --git a/api/model/chat_complete/bot_persona.py b/api/model/chat_complete/bot_persona.py index 05e9448..7d34670 100644 --- a/api/model/chat_complete/bot_persona.py +++ b/api/model/chat_complete/bot_persona.py @@ -28,6 +28,7 @@ class BotPersonaModel(BaseModel): ), index=True, ) + api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) diff --git a/api/model/embedding_search/page_index.py b/api/model/embedding_search/page_index.py index 16bdaf7..2baaeee 100644 --- a/api/model/embedding_search/page_index.py +++ b/api/model/embedding_search/page_index.py @@ -6,7 +6,7 @@ import asyncpg from api.model.base import BaseModel import numpy as np import sqlalchemy -from config import Config +from lib.config import Config from sqlalchemy import Index, select, update, delete, Select from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.ext.asyncio import AsyncSession @@ -17,7 +17,7 @@ from service.database import DatabaseService page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {} -embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) +embedding_vector_size = Config.get("embedding.vector_size", 1536, int) class AbstractPageIndexModel(BaseModel): __abstract__ = True diff --git a/api/model/embedding_search/title_index.py b/api/model/embedding_search/title_index.py index 568c388..2ecad92 100644 --- a/api/model/embedding_search/title_index.py +++ b/api/model/embedding_search/title_index.py @@ -8,11 +8,11 @@ import sqlalchemy from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.ext.asyncio import AsyncEngine -from config import Config +from lib.config import Config from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService -embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) +embedding_vector_size = Config.get("embedding.vector_size", 1536, int) class TitleIndexModel(BaseModel): __tablename__ = "embedding_search_title_index" diff --git a/config.py b/config.py deleted file mode 100644 index dd99ab5..0000000 --- a/config.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations -from typing import TypeVar -import toml - -class Config: - values: dict = {} - - @staticmethod - def load_config(file): - with open(file, "r", encoding="utf-8") as f: - Config.values = toml.load(f) - - @staticmethod - def get(key: str, default=None, type=None, empty_is_none=False): - key_path = key.split(".") - value = Config.values - for k in key_path: - if k in value: - value = value[k] - else: - return default - - if empty_is_none and value == "": - return None - - if type == bool: - if isinstance(value, bool): - return value - elif isinstance(value, int) or isinstance(value, float): - return value != 0 - else: - return str(value).lower() in ("yes", "true", "1") - elif type == int: - return int(value) - elif type == float: - return float(value) - elif type == str: - return str(value) - elif type == list: - if not isinstance(value, list): - return [] - elif type == dict: - if not isinstance(value, dict): - return {} - else: - return value - - @staticmethod - def set(key: str, value): - key_path = key.split(".") - obj = Config.values - for k in key_path[:-1]: - if k not in obj: - obj[k] = {} - obj = obj[k] - obj[key_path[-1]] = value diff --git a/local.py b/local.py deleted file mode 100644 index 9febcce..0000000 --- a/local.py +++ /dev/null @@ -1,6 +0,0 @@ -import asyncio -from noawait import NoAwaitPool - -debug = False -loop = asyncio.new_event_loop() -noawait = NoAwaitPool(loop) \ No newline at end of file diff --git a/main.py b/main.py index 594ff26..dfff894 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,12 @@ import sys import traceback -from config import Config +from lib.config import Config Config.load_config("config.toml") -from local import loop, noawait +from utils.local import loop, noawait from aiohttp import web -import local +import utils.local as local import api.route import utils.web from service.database import DatabaseService diff --git a/maintenance/update_title_index.py b/maintenance/update_title_index.py index 7114ba3..b1123cb 100644 --- a/maintenance/update_title_index.py +++ b/maintenance/update_title_index.py @@ -1,7 +1,7 @@ import asyncio import sys import base as _ -import local +import utils.local as local from service.database import DatabaseService from service.embedding_search import EmbeddingSearchService diff --git a/noawait.py b/noawait.py deleted file mode 100644 index e0e6c20..0000000 --- a/noawait.py +++ /dev/null @@ -1,165 +0,0 @@ -from __future__ import annotations -from asyncio import AbstractEventLoop, Task -import asyncio -import atexit -from functools import wraps -import random - -import sys -import traceback -from typing import Callable, Coroutine, Optional, TypedDict - -class TimerInfo(TypedDict): - id: int - callback: Callable - interval: float - next_time: float - -class NoAwaitPool: - def __init__(self, loop: AbstractEventLoop): - self.task_list: list[Task] = [] - self.timer_map: dict[int, TimerInfo] = {} - self.loop = loop - self.running = True - - self.should_refresh_task = False - self.next_timer_time: Optional[float] = None - - self.on_error: list[Callable] = [] - - self.gc_task = loop.create_task(self._run_gc()) - self.timer_task = loop.create_task(self._run_timer()) - - atexit.register(self.end_task) - - async def end(self): - if self.running: - print("Stopping NoAwait Tasks...") - self.running = False - for task in self.task_list: - await self._finish_task(task) - - await self.gc_task - await self.timer_task - - def end_task(self): - if self.running and not self.loop.is_closed(): - self.loop.run_until_complete(self.end()) - - async def _wrap_task(self, task: Task): - try: - await task - except Exception as e: - handled = False - for handler in self.on_error: - try: - handler_ret = handler(e) - await handler_ret - handled = True - except Exception as handler_err: - print("Exception on error handler: " + str(handler_err), file=sys.stderr) - traceback.print_exc() - - if not handled: - print(e, file=sys.stderr) - traceback.print_exc() - finally: - self.should_refresh_task = True - - def add_task(self, coroutine: Coroutine): - task = self.loop.create_task(coroutine) - self.task_list.append(task) - - def add_timer(self, callback: Callable, interval: float) -> int: - id = random.randint(0, 1000000000) - while id in self.timer_map: - id = random.randint(0, 1000000000) - - now = self.loop.time() - next_time = now + interval - self.timer_map[id] = { - "id": id, - "callback": callback, - "interval": interval, - "next_time": next_time - } - - if self.next_timer_time is None or next_time < self.next_timer_time: - self.next_timer_time = next_time - - return id - - def remove_timer(self, id: int): - if id in self.timer_map: - del self.timer_map[id] - - def wrap(self, f): - @wraps(f) - def decorated_function(*args, **kwargs): - coroutine = f(*args, **kwargs) - self.add_task(coroutine) - - return decorated_function - - async def _finish_task(self, task: Task): - try: - if not task.done(): - task.cancel() - await task - except Exception as e: - handled = False - for handler in self.on_error: - try: - handler_ret = handler(e) - await handler_ret - handled = True - except Exception as handler_err: - print("Exception on error handler: " + str(handler_err), file=sys.stderr) - traceback.print_exc() - - if not handled: - print(e, file=sys.stderr) - traceback.print_exc() - - async def _run_gc(self): - while self.running: - if self.should_refresh_task: - should_remove = [] - for task in self.task_list: - if task.done(): - await self._finish_task(task) - should_remove.append(task) - for task in should_remove: - self.task_list.remove(task) - - await asyncio.sleep(0.1) - - async def _run_timer(self): - while self.running: - now = self.loop.time() - if self.next_timer_time is not None and now >= self.next_timer_time: - self.next_timer_time = None - for timer in self.timer_map.values(): - if now >= timer["next_time"]: - timer["next_time"] = now + timer["interval"] - try: - result = timer["callback"]() - self.add_task(result) - except Exception as e: - handled = False - for handler in self.on_error: - try: - handler_ret = handler(e) - self.add_task(handler_ret) - handled = True - except Exception as handler_err: - print("Exception on error handler: " + str(handler_err), file=sys.stderr) - traceback.print_exc() - - if not handled: - print(e, file=sys.stderr) - traceback.print_exc() - if self.next_timer_time is None or timer["next_time"] < self.next_timer_time: - self.next_timer_time = timer["next_time"] - - await asyncio.sleep(0.1) \ No newline at end of file diff --git a/requirements-embedding.txt b/requirements/embedding.txt similarity index 54% rename from requirements-embedding.txt rename to requirements/embedding.txt index 3958881..820eef8 100644 --- a/requirements-embedding.txt +++ b/requirements/embedding.txt @@ -1,5 +1,3 @@ -transformers +text2vec>=1.2.9 --index-url https://download.pytorch.org/whl/cpu -torch -torchvision -torchaudio \ No newline at end of file +torch \ No newline at end of file diff --git a/service/bert_embedding.py b/service/bert_embedding.py deleted file mode 100644 index c5cdece..0000000 --- a/service/bert_embedding.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations -import time -from config import Config -import asyncio -import random -import threading -from typing import Optional, TypedDict -import torch -from transformers import pipeline -from local import loop - -from service.tiktoken import TikTokenService - -BERT_EMBEDDING_QUEUE_TIMEOUT = 1 - -class BERTEmbeddingQueueTaskInfo(TypedDict): - task_id: int - text: str - embedding: torch.Tensor - -class BERTEmbeddingQueue: - def init(self): - self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese") - self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {} - self.task_list: list[BERTEmbeddingQueueTaskInfo] = [] - self.lock = threading.Lock() - - self.thread: Optional[threading.Thread] = None - self.running = False - - async def get_embeddings(self, text: str): - text = "[CLS]" + text + "[SEP]" - task_id = random.randint(0, 1000000000) - with self.lock: - while task_id in self.task_map: - task_id = random.randint(0, 1000000000) - - task_info = { - "task_id": task_id, - "text": text, - "embedding": None - } - self.task_map[task_id] = task_info - self.task_list.append(task_info) - - self.start_queue() - while True: - task_info = self.pop_task(task_id) - if task_info is not None: - return task_info["embedding"] - - await asyncio.sleep(0.01) - - def pop_task(self, task_id): - with self.lock: - if task_id in self.task_map: - task_info = self.task_map[task_id] - if task_info["embedding"] is not None: - del self.task_map[task_id] - return task_info - - return None - - def run(self): - running = True - last_task_time = None - while running and self.running: - current_time = time.time() - task = None - with self.lock: - if len(self.task_list) > 0: - task = self.task_list.pop(0) - - if task is not None: - embeddings = self.embedding_model(task["text"]) - - with self.lock: - task["embedding"] = embeddings[0][1] - - last_task_time = time.time() - elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT: - self.thread = None - self.running = False - running = False - else: - time.sleep(0.01) - - def start_queue(self): - if not self.running: - self.running = True - self.thread = threading.Thread(target=self.run) - self.thread.start() - -bert_embedding_queue = BERTEmbeddingQueue() -bert_embedding_queue.init() - -class BERTEmbeddingService: - instance = None - - @staticmethod - async def create() -> BERTEmbeddingService: - if BERTEmbeddingService.instance is None: - BERTEmbeddingService.instance = BERTEmbeddingService() - await BERTEmbeddingService.instance.init() - return BERTEmbeddingService.instance - - async def init(self): - self.tiktoken = await TikTokenService.create() - self.embedding_queue = BERTEmbeddingQueue() - await loop.run_in_executor(None, self.embedding_queue.init) - - async def get_embeddings(self, docs, on_progress=None): - if len(docs) == 0: - return ([], 0) - - if on_progress is not None: - await on_progress(0, len(docs)) - - embeddings = [] - token_usage = 0 - - for doc in docs: - if "text" in doc: - tokens = await self.tiktoken.get_tokens(doc["text"]) - token_usage += tokens - embeddings.append({ - "id": doc["id"], - "text": doc["text"], - "embedding": self.model.encode(doc["text"]), - "tokens": tokens - }) - else: - embeddings.append({ - "id": doc["id"], - "text": doc["text"], - "embedding": None, - "tokens": 0 - }) - - if on_progress is not None: - await on_progress(1, len(docs)) - - return (embeddings, token_usage) \ No newline at end of file diff --git a/service/chat_complete.py b/service/chat_complete.py index af60196..1726e7c 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import ( import sys from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel -from config import Config +from lib.config import Config import utils.config, utils.web from aiohttp import web @@ -20,7 +20,7 @@ from sqlalchemy.orm.attributes import flag_modified from service.database import DatabaseService from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService from service.mediawiki_api import MediaWikiApi -from service.openai_api import OpenAIApi +from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException from service.tiktoken import TikTokenService class ChatCompleteQuestionTooLongException(Exception): @@ -34,6 +34,7 @@ class ChatCompleteServicePrepareResponse(TypedDict): question_tokens: int conversation_id: int chunk_id: int + api_id: str class ChatCompleteServiceResponse(TypedDict): message: str @@ -60,12 +61,12 @@ class ChatCompleteService: self.conversation_info: Optional[ConversationModel] = None self.conversation_chunk: Optional[ConversationChunkModel] = None + self.openai_api: OpenAIApi = None self.tiktoken: TikTokenService = None self.extract_doc: list = None self.mwapi = MediaWikiApi.create() - self.openai_api = OpenAIApi.create() self.user_id = 0 self.question = "" @@ -155,6 +156,12 @@ class ChatCompleteService: bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id) else: self.chat_system_prompt = bot_persona.system_prompt + + default_api = Config.get("chatcomplete.default_api", None, str) + try: + self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api) + except OpenAIApiTypeInvalidException: + self.openai_api = OpenAIApi.create(default_api) self.conversation_chunk = None if self.conversation_info is not None: diff --git a/service/database.py b/service/database.py index da05ad2..6d8f1e9 100644 --- a/service/database.py +++ b/service/database.py @@ -1,10 +1,10 @@ from __future__ import annotations -import local +import utils.local as local from urllib.parse import quote_plus from aiohttp import web import asyncpg from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from config import Config +from lib.config import Config def get_dsn(): db_conf = Config.get("database") diff --git a/service/embedding_search.py b/service/embedding_search.py index 7c41774..7f4a52a 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -11,6 +11,7 @@ from api.model.embedding_search.page_index import PageIndexHelper from service.database import DatabaseService from service.mediawiki_api import MediaWikiApi from service.openai_api import OpenAIApi +from service.text_embedding import TextEmbeddingService from service.tiktoken import TikTokenService from utils.wiki import getWikiSentences @@ -38,6 +39,7 @@ class EmbeddingSearchService: self.title_collection_helper = TitleCollectionHelper(dbs) self.page_index: PageIndexHelper = None + self.text_embedding: TextEmbeddingService = None self.tiktoken: TikTokenService = None self.mwapi = MediaWikiApi.create() @@ -53,6 +55,7 @@ class EmbeddingSearchService: self.unindexed_docs: list = None async def __aenter__(self): + self.text_embedding = await TextEmbeddingService.create() self.tiktoken = await TikTokenService.create() await self.title_index_helper.__aenter__() @@ -225,7 +228,7 @@ class EmbeddingSearchService: await on_progress(indexed_docs, len(self.unindexed_docs)) async def embedding_doc(doc_chunk): - (doc_chunk, token_usage) = await self.openai_api.get_embeddings( + (doc_chunk, token_usage) = await self.text_embedding.get_embeddings( doc_chunk, on_embedding_progress ) await self.page_index.index_doc(doc_chunk, self.page_id) @@ -266,7 +269,7 @@ class EmbeddingSearchService: # Update title embedding if await self.title_index.awaitable_attrs.embedding is None: doc_chunk = [{"text": self.title}] - (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) + (doc_chunk, token_usage) = await self.text_embedding.get_embeddings(doc_chunk) total_token_usage += token_usage embedding = doc_chunk[0]["embedding"] @@ -290,7 +293,7 @@ class EmbeddingSearchService: raise Exception("Page index is not initialized") query_doc = [{"text": query}] - query_doc, token_usage = await self.openai_api.get_embeddings(query_doc) + query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc) query_embedding = query_doc[0]["embedding"] if query_embedding is None: diff --git a/service/mediawiki_api.py b/service/mediawiki_api.py index ac35ebb..24cccbe 100644 --- a/service/mediawiki_api.py +++ b/service/mediawiki_api.py @@ -4,7 +4,7 @@ import sys import time from typing import Optional, TypedDict import aiohttp -from config import Config +from lib.config import Config class MediaWikiApiException(Exception): def __init__(self, info: str, code: Optional[str] = None) -> None: diff --git a/service/openai_api.py b/service/openai_api.py index d81320e..97a7f62 100644 --- a/service/openai_api.py +++ b/service/openai_api.py @@ -4,7 +4,7 @@ import json from typing import Callable, Optional, TypedDict import aiohttp -from config import Config +from lib.config import Config import numpy as np from aiohttp_sse_client2 import client as sse_client @@ -13,6 +13,13 @@ from service.tiktoken import TikTokenService AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview" AZURE_EMBEDDING_API_VERSION = "2023-05-15" +class OpenAIApiTypeInvalidException(Exception): + def __init__(self, api_id: str): + self.api_id = api_id + + def __str__(self): + return f"Invalid api_id: {self.api_id}" + class ChatCompleteMessageLog(TypedDict): role: str content: str @@ -26,19 +33,21 @@ class ChatCompleteResponse(TypedDict): class OpenAIApi: @staticmethod - def create(): - return OpenAIApi() + def create(api_id: str) -> OpenAIApi: + return OpenAIApi(api_id) - def __init__(self): - self.api_type = Config.get("chatcomplete.api_type", "openai", str) - self.request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True) + def __init__(self, api_id: str): + self.api_id = api_id - if self.api_type == "azure": - self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str) - self.api_key = Config.get("chatcomplete.azure.key", type=str) - else: - self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str) - self.api_key = Config.get("chatcomplete.openai.key", type=str) + self.api_type = Config.get(f"chatcomplete.{api_id}.api_type", None, str) + + if self.api_type is None: + raise OpenAIApiTypeInvalidException(api_id) + + self.request_proxy = Config.get(f"chatcomplete.{api_id}.request_proxy", type=str, empty_is_none=True) + + self.api_url = Config.get(f"chatcomplete.{api_id}.api_endpoint", type=str) + self.api_key = Config.get(f"chatcomplete.{api_id}.key", type=str) def build_header(self): if self.api_type == "azure": @@ -56,7 +65,7 @@ class OpenAIApi: def get_url(self, method: str): if self.api_type == "azure": - deployments = Config.get("chatcomplete.azure.deployments") + deployments = Config.get(f"chatcomplete.{self.api_id}.deployments") if method == "chat/completions": return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method elif method == "embeddings": @@ -65,25 +74,10 @@ class OpenAIApi: return self.api_url + "/v1/" + method async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): - text_list = [] - regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]" - for doc in doc_list: - text: str = doc["text"] - text = text.replace("\r\n", "\n").replace("\r", "\n") - if "\n" in text: - lines = text.split("\n") - new_lines = [] - for line in lines: - line = line.strip() - # Add a dot at the end of the line if it doesn't end with a punctuation mark - if len(line) > 0 and regex.find(line[-1]) == -1: - line += "." - new_lines.append(line) - text = " ".join(new_lines) - text_list.append(text) - token_usage = 0 + text_list = [doc["text"] for doc in doc_list] + async with aiohttp.ClientSession() as session: url = self.get_url("embeddings") params = {} diff --git a/service/text_embedding.py b/service/text_embedding.py new file mode 100644 index 0000000..cf79bb5 --- /dev/null +++ b/service/text_embedding.py @@ -0,0 +1,152 @@ +from __future__ import annotations +from copy import deepcopy +import time +from lib.config import Config +import asyncio +import random +import threading +from typing import Callable, Optional, TypedDict +import torch +from text2vec import SentenceModel +from utils.local import loop +from service.openai_api import OpenAIApi +from service.tiktoken import TikTokenService + +BERT_EMBEDDING_QUEUE_TIMEOUT = 1 + +class Text2VecEmbeddingQueueTaskInfo(TypedDict): + task_id: int + text: str + embedding: torch.Tensor + +class Text2VecEmbeddingQueue: + def __init__(self, model: str) -> None: + self.model_name = model + + self.embedding_model = SentenceModel(self.model_name) + self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {} + self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = [] + self.lock = threading.Lock() + + self.thread: Optional[threading.Thread] = None + self.running = False + + + async def get_embeddings(self, text: str): + task_id = random.randint(0, 1000000000) + with self.lock: + while task_id in self.task_map: + task_id = random.randint(0, 1000000000) + + task_info = { + "task_id": task_id, + "text": text, + "embedding": None + } + self.task_map[task_id] = task_info + self.task_list.append(task_info) + + self.start_queue() + while True: + task_info = self.pop_task(task_id) + if task_info is not None: + return task_info["embedding"] + + await asyncio.sleep(0.01) + + def pop_task(self, task_id): + with self.lock: + if task_id in self.task_map: + task_info = self.task_map[task_id] + if task_info["embedding"] is not None: + del self.task_map[task_id] + return task_info + + return None + + def run(self): + running = True + last_task_time = None + while running and self.running: + current_time = time.time() + task = None + with self.lock: + if len(self.task_list) > 0: + task = self.task_list.pop(0) + + if task is not None: + embeddings = self.embedding_model.encode([task["text"]]) + + with self.lock: + task["embedding"] = embeddings[0] + + last_task_time = time.time() + elif last_task_time is not None and current_time > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT: + self.thread = None + self.running = False + running = False + else: + time.sleep(0.01) + + def start_queue(self): + if not self.running: + self.running = True + self.thread = threading.Thread(target=self.run) + self.thread.start() + +class TextEmbeddingService: + instance = None + + @staticmethod + async def create() -> TextEmbeddingService: + if TextEmbeddingService.instance is None: + TextEmbeddingService.instance = TextEmbeddingService() + await TextEmbeddingService.instance.init() + return TextEmbeddingService.instance + + async def init(self): + self.tiktoken = await TikTokenService.create() + + self.embedding_type = Config.get("embedding.type", "text2vec") + + if self.embedding_type == "text2vec": + embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese") + self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model) + elif self.embedding_type == "openai": + self.openai_api: OpenAIApi = await OpenAIApi.create() + + await loop.run_in_executor(None, self.embedding_queue.init) + + async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): + for index, doc in enumerate(doc_list): + text = doc["text"] + embedding = await self.embedding_queue.get_embeddings(text) + doc["embedding"] = embedding + + if on_index_progress is not None: + await on_index_progress(index, len(doc_list)) + + + async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): + res_doc_list = deepcopy(doc_list) + + regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]" + for doc in res_doc_list: + text: str = doc["text"] + text = text.replace("\r\n", "\n").replace("\r", "\n") + if "\n" in text: + lines = text.split("\n") + new_lines = [] + for line in lines: + line = line.strip() + # Add a dot at the end of the line if it doesn't end with a punctuation mark + if len(line) > 0 and regex.find(line[-1]) == -1: + line += "." + new_lines.append(line) + text = " ".join(new_lines) + doc["text"] = text + + if self.embedding_type == "text2vec": + return await self.get_text2vec_embeddings(res_doc_list, on_index_progress) + elif self.embedding_type == "openai": + return await self.openai_api.get_embeddings(res_doc_list, on_index_progress) \ No newline at end of file diff --git a/test/base.py b/test/base.py index 113ba39..74c4dae 100644 --- a/test/base.py +++ b/test/base.py @@ -2,9 +2,9 @@ import sys import pathlib root_path = pathlib.Path(__file__).parent.parent -sys.path.append(root_path) +sys.path.append(".") -from config import Config +from lib.config import Config -Config.load_config(root_path + "/config.toml") +Config.load_config(str(root_path) + "/config.toml") Config.set("server.debug", True) \ No newline at end of file diff --git a/test/chatcomplete.py b/test/chatcomplete.py index 3fc840c..c00360e 100644 --- a/test/chatcomplete.py +++ b/test/chatcomplete.py @@ -1,6 +1,6 @@ import traceback import base -import local +import utils.local as local from service.chat_complete import ChatCompleteService from service.database import DatabaseService diff --git a/test/create_token.py b/test/create_token.py index 1e63e63..345bd5e 100644 --- a/test/create_token.py +++ b/test/create_token.py @@ -3,7 +3,7 @@ import base from sqlalchemy import select from api.model.embedding_search.title_index import TitleIndexModel -import local +import utils.local as local from service.database import DatabaseService from service.embedding_search import EmbeddingSearchService diff --git a/test/embedding_search.py b/test/embedding_search.py index 5be0875..969a247 100644 --- a/test/embedding_search.py +++ b/test/embedding_search.py @@ -1,5 +1,5 @@ import base -import local +import utils.local as local from service.database import DatabaseService from service.embedding_search import EmbeddingSearchService diff --git a/test/bert_embedding_queue.py b/test/text2vec_embedding_queue.py similarity index 51% rename from test/bert_embedding_queue.py rename to test/text2vec_embedding_queue.py index a5b0cd2..b9ce694 100644 --- a/test/bert_embedding_queue.py +++ b/test/text2vec_embedding_queue.py @@ -1,13 +1,21 @@ import asyncio import time -import base -from local import loop, noawait -from service.bert_embedding import bert_embedding_queue +import base as _ +from utils.local import loop, noawait +from service.text_embedding import Text2VecEmbeddingQueue async def main(): embedding_list = [] - start_time = time.time() queue = [] + text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese") + + start_time = time.time() + + async def on_progress(current, total): + print(f"{current}/{total}") + + total_lines = 0 + with open("test/test.md", "r", encoding="utf-8") as fp: text = fp.read() lines = text.split("\n") @@ -16,10 +24,13 @@ async def main(): if line == "": continue - queue.append(bert_embedding_queue.get_embeddings(line)) + queue.append(text2vec_queue.get_embeddings(line)) + total_lines += 0 embedding_list = await asyncio.gather(*queue) end_time = time.time() + print("total lines: %d" % total_lines) print("time cost: %.4f" % (end_time - start_time)) + print("speed: %.4f it/s" % (total_lines / (end_time - start_time))) print("dimensions: %d" % len(embedding_list[0])) await noawait.end() diff --git a/test/timer.py b/test/timer.py index b4c9e56..4078545 100644 --- a/test/timer.py +++ b/test/timer.py @@ -1,6 +1,6 @@ import asyncio import base -from local import loop, noawait +from utils.local import loop, noawait async def test_timer1(): print("timer1") diff --git a/test/title_index.py b/test/title_index.py index a339782..0a5a127 100644 --- a/test/title_index.py +++ b/test/title_index.py @@ -3,7 +3,7 @@ import base from sqlalchemy import select from api.model.embedding_search.title_index import TitleIndexModel -import local +import utils.local as local from service.database import DatabaseService from service.embedding_search import EmbeddingSearchService diff --git a/utils/config.py b/utils/config.py index e0b2fc6..e30f87a 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,5 +1,5 @@ import time -from config import Config +from lib.config import Config def get_prompt(name: str, type: str, params: dict = {}): sys_params = { diff --git a/utils/local.py b/utils/local.py new file mode 100644 index 0000000..3c69ea0 --- /dev/null +++ b/utils/local.py @@ -0,0 +1,6 @@ +import asyncio +from lib.noawait import NoAwaitPool + +debug = False +loop = asyncio.new_event_loop() +noawait = NoAwaitPool(loop) \ No newline at end of file diff --git a/utils/web.py b/utils/web.py index e3f289e..d7109e1 100644 --- a/utils/web.py +++ b/utils/web.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Dict from aiohttp import web import jwt import uuid -from config import Config +from lib.config import Config ParamRule = Dict[str, Any]