完成ChatComplete重构

master
落雨楓 9 months ago
parent 2a1e5c1589
commit 1355dbbf35

@ -6,7 +6,7 @@ from api.controller.task.ChatCompleteTask import ChatCompleteTask
from api.model.base import clone_model
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait
from utils.local import noawait
from aiohttp import web
from api.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel
from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse

@ -2,7 +2,7 @@ from __future__ import annotations
import sys
import time
import traceback
from local import noawait
from utils.local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import (
ChatCompleteService,

@ -28,6 +28,7 @@ class BotPersonaModel(BaseModel):
),
index=True,
)
api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)

@ -6,7 +6,7 @@ import asyncpg
from api.model.base import BaseModel
import numpy as np
import sqlalchemy
from config import Config
from lib.config import Config
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
@ -17,7 +17,7 @@ from service.database import DatabaseService
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
embedding_vector_size = Config.get("embedding.vector_size", 1536, int)
class AbstractPageIndexModel(BaseModel):
__abstract__ = True

@ -8,11 +8,11 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine
from config import Config
from lib.config import Config
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
embedding_vector_size = Config.get("embedding.vector_size", 1536, int)
class TitleIndexModel(BaseModel):
__tablename__ = "embedding_search_title_index"

@ -1,56 +0,0 @@
from __future__ import annotations
from typing import TypeVar
import toml
class Config:
    """Process-wide configuration store loaded from a TOML file.

    Values live in the class attribute ``values`` as nested dicts and are
    addressed with dotted keys such as ``"database.host"``.
    """

    values: dict = {}

    @staticmethod
    def load_config(file):
        """Parse the UTF-8 TOML file at *file* and replace ``Config.values``."""
        with open(file, "r", encoding="utf-8") as f:
            Config.values = toml.load(f)

    @staticmethod
    def get(key: str, default=None, type=None, empty_is_none=False):
        """Look up a dotted *key* and optionally coerce the result.

        Args:
            key: Dotted path, e.g. ``"chatcomplete.default_api"``.
            default: Returned when any path segment is missing.
            type: One of ``bool``, ``int``, ``float``, ``str``, ``list``,
                ``dict`` or ``None`` (no coercion).
            empty_is_none: When true, an empty-string value yields ``None``.

        Returns:
            The (coerced) value, or *default* when the key does not exist.
            For ``list``/``dict`` a value of the wrong type yields an empty
            container of the requested type.

        Raises:
            ValueError/TypeError: from ``int()``/``float()`` when the stored
                value cannot be converted.
        """
        value = Config.values
        for part in key.split("."):
            # Only descend into mappings: a scalar in the middle of the path
            # means the key does not exist (previously this did a substring
            # test on strings or raised TypeError on numbers).
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return default

        if empty_is_none and value == "":
            return None

        if type is bool:
            if isinstance(value, bool):
                return value
            if isinstance(value, (int, float)):
                return value != 0
            return str(value).lower() in ("yes", "true", "1")
        if type is int:
            return int(value)
        if type is float:
            return float(value)
        if type is str:
            return str(value)
        if type is list:
            # Fix: a real list used to fall through and return None.
            return value if isinstance(value, list) else []
        if type is dict:
            # Fix: a real dict used to fall through and return None.
            return value if isinstance(value, dict) else {}
        return value

    @staticmethod
    def set(key: str, value):
        """Store *value* under the dotted *key*, creating missing sections."""
        *parents, leaf = key.split(".")
        node = Config.values
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value

@ -1,6 +0,0 @@
import asyncio
from noawait import NoAwaitPool
# Module-level debug flag, off by default. Where it is read is not visible here.
debug = False
# Dedicated event loop created at import time (new_event_loop does not make it
# the current loop; callers import `loop` from this module instead).
loop = asyncio.new_event_loop()
# Shared fire-and-forget task pool bound to `loop`.
noawait = NoAwaitPool(loop)

@ -1,12 +1,12 @@
import sys
import traceback
from config import Config
from lib.config import Config
Config.load_config("config.toml")
from local import loop, noawait
from utils.local import loop, noawait
from aiohttp import web
import local
import utils.local as local
import api.route
import utils.web
from service.database import DatabaseService

@ -1,7 +1,7 @@
import asyncio
import sys
import base as _
import local
import utils.local as local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService

@ -1,165 +0,0 @@
from __future__ import annotations
from asyncio import AbstractEventLoop, Task
import asyncio
import atexit
from functools import wraps
import random
import sys
import traceback
from typing import Callable, Coroutine, Optional, TypedDict
# Bookkeeping record for one repeating timer registered on a NoAwaitPool.
class TimerInfo(TypedDict):
# Random integer handle; also the key of this record in NoAwaitPool.timer_map.
id: int
# Invoked with no arguments each time the timer fires.
callback: Callable
# Delay between firings, in the units of loop.time().
interval: float
# loop.time() value at or after which the timer fires next.
next_time: float
class NoAwaitPool:
    """Fire-and-forget coroutine pool with repeating timers on one event loop.

    Coroutines handed to :meth:`add_task` run in the background. Exceptions
    they raise are passed to every callable in ``on_error`` (sync or async);
    if no handler succeeds the exception is printed to stderr. Two
    housekeeping tasks poll every 0.1 s: one drops finished tasks from
    ``task_list``, the other fires due timers.
    """

    def __init__(self, loop: AbstractEventLoop):
        """Bind the pool to *loop* and schedule its housekeeping tasks."""
        self.task_list: list[Task] = []
        self.timer_map: dict[int, TimerInfo] = {}
        self.loop = loop
        self.running = True
        self.should_refresh_task = False
        self.next_timer_time: Optional[float] = None
        self.on_error: list[Callable] = []

        self.gc_task = loop.create_task(self._run_gc())
        self.timer_task = loop.create_task(self._run_timer())

        atexit.register(self.end_task)

    async def end(self):
        """Stop the pool: cancel unfinished tasks and wait for housekeeping."""
        if self.running:
            print("Stopping NoAwait Tasks...")
            self.running = False
            # Iterate over a snapshot: the GC task may still be trimming
            # task_list while we await here.
            for task in list(self.task_list):
                await self._finish_task(task)
            await self.gc_task
            await self.timer_task
            self.task_list.clear()

    def end_task(self):
        """Synchronous shutdown hook (registered with ``atexit``)."""
        # run_until_complete() raises on a closed or already-running loop.
        if self.running and not self.loop.is_closed() and not self.loop.is_running():
            self.loop.run_until_complete(self.end())

    async def _report_error(self, error: Exception):
        """Dispatch *error* to ``on_error`` handlers, else print it."""
        handled = False
        for handler in list(self.on_error):
            try:
                handler_ret = handler(error)
                # Handlers may be plain functions; only await awaitables.
                if hasattr(handler_ret, "__await__"):
                    await handler_ret
                handled = True
            except Exception as handler_err:
                print("Exception on error handler: " + str(handler_err), file=sys.stderr)
                traceback.print_exception(type(handler_err), handler_err, handler_err.__traceback__)
        if not handled:
            print(error, file=sys.stderr)
            traceback.print_exception(type(error), error, error.__traceback__)

    async def _wrap_task(self, task: Coroutine):
        """Await *task*, report its exception, then flag the pool for GC."""
        try:
            await task
        except Exception as e:
            await self._report_error(e)
        finally:
            # Runs on success, failure and cancellation alike.
            self.should_refresh_task = True

    def add_task(self, coroutine: Coroutine):
        """Schedule *coroutine* to run in the background."""
        # Fix: route through _wrap_task so errors are reported when they
        # happen and the GC is told there is something to collect.
        task = self.loop.create_task(self._wrap_task(coroutine))
        self.task_list.append(task)

    def add_timer(self, callback: Callable, interval: float) -> int:
        """Call *callback* every *interval* seconds; return the timer id."""
        timer_id = random.randint(0, 1000000000)
        while timer_id in self.timer_map:
            timer_id = random.randint(0, 1000000000)

        next_time = self.loop.time() + interval
        self.timer_map[timer_id] = {
            "id": timer_id,
            "callback": callback,
            "interval": interval,
            "next_time": next_time,
        }
        if self.next_timer_time is None or next_time < self.next_timer_time:
            self.next_timer_time = next_time
        return timer_id

    def remove_timer(self, id: int):
        """Unregister the timer with handle *id*; unknown ids are ignored."""
        self.timer_map.pop(id, None)

    def wrap(self, f):
        """Decorator: calling the wrapped coroutine function schedules it."""
        @wraps(f)
        def decorated_function(*args, **kwargs):
            coroutine = f(*args, **kwargs)
            self.add_task(coroutine)
        return decorated_function

    async def _finish_task(self, task: Task):
        """Cancel *task* if still pending, await it, and report its error."""
        try:
            if not task.done():
                task.cancel()
            await task
        except asyncio.CancelledError:
            # CancelledError is a BaseException, so `except Exception` would
            # let it escape. Swallow it only when it is the task's own
            # cancellation, not a cancellation of our caller.
            if not task.cancelled():
                raise
        except Exception as e:
            await self._report_error(e)

    async def _run_gc(self):
        """Housekeeping loop: drop finished tasks from ``task_list``."""
        while self.running:
            if self.should_refresh_task:
                # Reset first so completions during this pass trigger another.
                self.should_refresh_task = False
                finished = {task for task in self.task_list if task.done()}
                for task in finished:
                    await self._finish_task(task)
                self.task_list[:] = [t for t in self.task_list if t not in finished]
            await asyncio.sleep(0.1)

    async def _run_timer(self):
        """Housekeeping loop: fire due timers and compute the next deadline."""
        while self.running:
            now = self.loop.time()
            if self.next_timer_time is not None and now >= self.next_timer_time:
                self.next_timer_time = None
                # Snapshot: callbacks may add or remove timers while we iterate.
                for timer in list(self.timer_map.values()):
                    if timer["id"] not in self.timer_map:
                        continue  # removed by an earlier callback in this pass
                    if now >= timer["next_time"]:
                        timer["next_time"] = now + timer["interval"]
                        try:
                            result = timer["callback"]()
                            # Plain (non-async) callbacks return nothing to schedule.
                            if hasattr(result, "__await__"):
                                self.add_task(result)
                        except Exception as e:
                            self.add_task(self._report_error(e))
                    if timer["id"] in self.timer_map and (
                        self.next_timer_time is None
                        or timer["next_time"] < self.next_timer_time
                    ):
                        self.next_timer_time = timer["next_time"]
            await asyncio.sleep(0.1)

@ -1,5 +1,3 @@
transformers
text2vec>=1.2.9
--index-url https://download.pytorch.org/whl/cpu
torch
torchvision
torchaudio

@ -1,143 +0,0 @@
from __future__ import annotations
import time
from config import Config
import asyncio
import random
import threading
from typing import Optional, TypedDict
import torch
from transformers import pipeline
from local import loop
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
# One pending or completed embedding request handled by BERTEmbeddingQueue.
class BERTEmbeddingQueueTaskInfo(TypedDict):
# Random integer handle; key of this record in BERTEmbeddingQueue.task_map.
task_id: int
# Text passed to the model (get_embeddings wraps it in "[CLS]"/"[SEP]").
text: str
# None until the worker thread stores the result.
# NOTE(review): annotated as torch.Tensor, but the stored value is whatever
# the feature-extraction pipeline returns at [0][1] — confirm the real type.
embedding: torch.Tensor
class BERTEmbeddingQueue:
    """Serialises embedding requests onto a single worker thread.

    ``get_embeddings`` enqueues a text and polls until the worker has filled
    in the result. The worker thread is started on demand and exits after
    ``BERT_EMBEDDING_QUEUE_TIMEOUT`` seconds without work.
    """

    def init(self):
        """Build the feature-extraction pipeline and reset queue state."""
        self.embedding_model = pipeline("feature-extraction", model="bert-base-chinese")
        self.task_map: dict[int, BERTEmbeddingQueueTaskInfo] = {}
        self.task_list: list[BERTEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Return the embedding for *text*, computed on the worker thread.

        Raises:
            Exception: whatever the model raised while embedding this text.
        """
        text = "[CLS]" + text + "[SEP]"
        with self.lock:
            task_id = random.randint(0, 1000000000)
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task_info = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)
            # Start (or keep) the worker while still holding the lock so the
            # worker cannot decide to exit between our append and this check.
            self._start_queue_locked()

        try:
            while True:
                task_info = self.pop_task(task_id)
                if task_info is not None:
                    error = task_info.get("error")
                    if error is not None:
                        raise error
                    return task_info["embedding"]
                await asyncio.sleep(0.01)
        finally:
            # If the waiter is cancelled, do not leak its map entry.
            with self.lock:
                self.task_map.pop(task_id, None)

    def pop_task(self, task_id):
        """Remove and return the task if it has finished, else ``None``."""
        with self.lock:
            task_info = self.task_map.get(task_id)
            if task_info is None:
                return None
            if task_info["embedding"] is None and task_info.get("error") is None:
                return None
            del self.task_map[task_id]
            return task_info

    def run(self):
        """Worker-thread body: process queued tasks until idle for too long."""
        last_task_time = None
        while True:
            task = None
            with self.lock:
                if not self.running:
                    return
                if self.task_list:
                    task = self.task_list.pop(0)
                elif (last_task_time is not None
                        and time.time() > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT):
                    # Decide to stop while holding the lock and only when the
                    # queue is empty, so no enqueued task can be stranded.
                    self.thread = None
                    self.running = False
                    return

            if task is None:
                time.sleep(0.01)
                continue

            try:
                embeddings = self.embedding_model(task["text"])
                result = embeddings[0][1]
            except Exception as e:
                # Hand the failure to the waiter instead of killing the
                # worker and leaving every waiter polling forever.
                with self.lock:
                    task["error"] = e
            else:
                with self.lock:
                    task["embedding"] = result
            last_task_time = time.time()

    def _start_queue_locked(self):
        """Start the worker thread if needed. Caller must hold ``self.lock``."""
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()

    def start_queue(self):
        """Start the worker thread if it is not already running."""
        with self.lock:
            self._start_queue_locked()
# Module-level queue shared by importers of this module. init() builds the
# "bert-base-chinese" feature-extraction pipeline, so this runs at import time.
bert_embedding_queue = BERTEmbeddingQueue()
bert_embedding_queue.init()
class BERTEmbeddingService:
    """Singleton service that embeds documents through a BERTEmbeddingQueue."""

    instance = None

    @staticmethod
    async def create() -> BERTEmbeddingService:
        """Return the shared instance, initialising it on first use."""
        if BERTEmbeddingService.instance is None:
            # Publish the instance only after init() has finished, so a
            # concurrent caller never receives a half-initialised service.
            service = BERTEmbeddingService()
            await service.init()
            if BERTEmbeddingService.instance is None:
                BERTEmbeddingService.instance = service
        return BERTEmbeddingService.instance

    async def init(self):
        """Create the token counter and the embedding queue."""
        self.tiktoken = await TikTokenService.create()
        self.embedding_queue = BERTEmbeddingQueue()
        # Queue init builds the model pipeline; keep it off the event loop.
        await loop.run_in_executor(None, self.embedding_queue.init)

    async def get_embeddings(self, docs, on_progress=None):
        """Embed every document in *docs*.

        Args:
            docs: List of dicts; ``"text"`` is embedded when present and
                ``"id"`` is copied through (``None`` when absent).
            on_progress: Optional async callback ``(done, total)``.

        Returns:
            ``(embeddings, token_usage)`` where *embeddings* is a list of
            dicts with keys ``id``, ``text``, ``embedding`` and ``tokens``.
        """
        total = len(docs)
        if total == 0:
            return ([], 0)

        if on_progress is not None:
            await on_progress(0, total)

        embeddings = []
        token_usage = 0

        for done, doc in enumerate(docs, start=1):
            text = doc.get("text")
            if text is not None:
                tokens = await self.tiktoken.get_tokens(text)
                token_usage += tokens
                # Fix: the original called self.model.encode, but no `model`
                # attribute is ever set; the queue is what init() creates.
                embedding = await self.embedding_queue.get_embeddings(text)
            else:
                # Fix: the original read doc["text"] here, which is exactly
                # the key known to be missing.
                tokens = 0
                embedding = None

            embeddings.append({
                "id": doc.get("id"),
                "text": text,
                "embedding": embedding,
                "tokens": tokens
            })

            if on_progress is not None:
                await on_progress(done, total)

        return (embeddings, token_usage)

@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import (
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
from config import Config
from lib.config import Config
import utils.config, utils.web
from aiohttp import web
@ -20,7 +20,7 @@ from sqlalchemy.orm.attributes import flag_modified
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.openai_api import OpenAIApi, OpenAIApiTypeInvalidException
from service.tiktoken import TikTokenService
class ChatCompleteQuestionTooLongException(Exception):
@ -34,6 +34,7 @@ class ChatCompleteServicePrepareResponse(TypedDict):
question_tokens: int
conversation_id: int
chunk_id: int
api_id: str
class ChatCompleteServiceResponse(TypedDict):
message: str
@ -60,12 +61,12 @@ class ChatCompleteService:
self.conversation_info: Optional[ConversationModel] = None
self.conversation_chunk: Optional[ConversationChunkModel] = None
self.openai_api: OpenAIApi = None
self.tiktoken: TikTokenService = None
self.extract_doc: list = None
self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.user_id = 0
self.question = ""
@ -156,6 +157,12 @@ class ChatCompleteService:
else:
self.chat_system_prompt = bot_persona.system_prompt
default_api = Config.get("chatcomplete.default_api", None, str)
try:
self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api)
except OpenAIApiTypeInvalidException:
self.openai_api = OpenAIApi.create(default_api)
self.conversation_chunk = None
if self.conversation_info is not None:
chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)

@ -1,10 +1,10 @@
from __future__ import annotations
import local
import utils.local as local
from urllib.parse import quote_plus
from aiohttp import web
import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from config import Config
from lib.config import Config
def get_dsn():
db_conf = Config.get("database")

@ -11,6 +11,7 @@ from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.text_embedding import TextEmbeddingService
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences
@ -38,6 +39,7 @@ class EmbeddingSearchService:
self.title_collection_helper = TitleCollectionHelper(dbs)
self.page_index: PageIndexHelper = None
self.text_embedding: TextEmbeddingService = None
self.tiktoken: TikTokenService = None
self.mwapi = MediaWikiApi.create()
@ -53,6 +55,7 @@ class EmbeddingSearchService:
self.unindexed_docs: list = None
async def __aenter__(self):
self.text_embedding = await TextEmbeddingService.create()
self.tiktoken = await TikTokenService.create()
await self.title_index_helper.__aenter__()
@ -225,7 +228,7 @@ class EmbeddingSearchService:
await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(
(doc_chunk, token_usage) = await self.text_embedding.get_embeddings(
doc_chunk, on_embedding_progress
)
await self.page_index.index_doc(doc_chunk, self.page_id)
@ -266,7 +269,7 @@ class EmbeddingSearchService:
# Update title embedding
if await self.title_index.awaitable_attrs.embedding is None:
doc_chunk = [{"text": self.title}]
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
(doc_chunk, token_usage) = await self.text_embedding.get_embeddings(doc_chunk)
total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"]
@ -290,7 +293,7 @@ class EmbeddingSearchService:
raise Exception("Page index is not initialized")
query_doc = [{"text": query}]
query_doc, token_usage = await self.openai_api.get_embeddings(query_doc)
query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc)
query_embedding = query_doc[0]["embedding"]
if query_embedding is None:

@ -4,7 +4,7 @@ import sys
import time
from typing import Optional, TypedDict
import aiohttp
from config import Config
from lib.config import Config
class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:

@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict
import aiohttp
from config import Config
from lib.config import Config
import numpy as np
from aiohttp_sse_client2 import client as sse_client
@ -13,6 +13,13 @@ from service.tiktoken import TikTokenService
AZURE_CHATCOMPLETE_API_VERSION = "2023-07-01-preview"
AZURE_EMBEDDING_API_VERSION = "2023-05-15"
# Raised by OpenAIApi.__init__ when the configuration has no
# `chatcomplete.<api_id>.api_type` entry for the requested api_id.
class OpenAIApiTypeInvalidException(Exception):
def __init__(self, api_id: str):
# Keep the offending id so callers can inspect it or fall back to a default.
self.api_id = api_id
def __str__(self):
return f"Invalid api_id: {self.api_id}"
class ChatCompleteMessageLog(TypedDict):
role: str
content: str
@ -26,19 +33,21 @@ class ChatCompleteResponse(TypedDict):
class OpenAIApi:
@staticmethod
def create():
return OpenAIApi()
def create(api_id: str) -> OpenAIApi:
return OpenAIApi(api_id)
def __init__(self):
self.api_type = Config.get("chatcomplete.api_type", "openai", str)
self.request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
def __init__(self, api_id: str):
self.api_id = api_id
if self.api_type == "azure":
self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.azure.key", type=str)
else:
self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.openai.key", type=str)
self.api_type = Config.get(f"chatcomplete.{api_id}.api_type", None, str)
if self.api_type is None:
raise OpenAIApiTypeInvalidException(api_id)
self.request_proxy = Config.get(f"chatcomplete.{api_id}.request_proxy", type=str, empty_is_none=True)
self.api_url = Config.get(f"chatcomplete.{api_id}.api_endpoint", type=str)
self.api_key = Config.get(f"chatcomplete.{api_id}.key", type=str)
def build_header(self):
if self.api_type == "azure":
@ -56,7 +65,7 @@ class OpenAIApi:
def get_url(self, method: str):
if self.api_type == "azure":
deployments = Config.get("chatcomplete.azure.deployments")
deployments = Config.get(f"chatcomplete.{self.api_id}.deployments")
if method == "chat/completions":
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
elif method == "embeddings":
@ -65,25 +74,10 @@ class OpenAIApi:
return self.api_url + "/v1/" + method
async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
text_list = []
regex = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
for doc in doc_list:
text: str = doc["text"]
text = text.replace("\r\n", "\n").replace("\r", "\n")
if "\n" in text:
lines = text.split("\n")
new_lines = []
for line in lines:
line = line.strip()
# Add a dot at the end of the line if it doesn't end with a punctuation mark
if len(line) > 0 and regex.find(line[-1]) == -1:
line += "."
new_lines.append(line)
text = " ".join(new_lines)
text_list.append(text)
token_usage = 0
text_list = [doc["text"] for doc in doc_list]
async with aiohttp.ClientSession() as session:
url = self.get_url("embeddings")
params = {}

@ -0,0 +1,152 @@
from __future__ import annotations
from copy import deepcopy
import time
from lib.config import Config
import asyncio
import random
import threading
from typing import Callable, Optional, TypedDict
import torch
from text2vec import SentenceModel
from utils.local import loop
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
BERT_EMBEDDING_QUEUE_TIMEOUT = 1
# One pending or completed embedding request handled by Text2VecEmbeddingQueue.
class Text2VecEmbeddingQueueTaskInfo(TypedDict):
# Random integer handle; key of this record in Text2VecEmbeddingQueue.task_map.
task_id: int
# Text passed to the sentence model.
text: str
# None until the worker thread stores element [0] of SentenceModel.encode().
# NOTE(review): annotated as torch.Tensor — confirm the type encode() returns.
embedding: torch.Tensor
class Text2VecEmbeddingQueue:
    """Serialises text2vec embedding requests onto a single worker thread.

    ``get_embeddings`` enqueues a text and polls until the worker has filled
    in the result. The worker thread is started on demand and exits after
    ``BERT_EMBEDDING_QUEUE_TIMEOUT`` seconds without work.
    """

    def __init__(self, model: str) -> None:
        """Create the ``SentenceModel`` named *model* and empty queue state."""
        self.model_name = model

        self.embedding_model = SentenceModel(self.model_name)
        self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {}
        self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = []
        self.lock = threading.Lock()

        self.thread: Optional[threading.Thread] = None
        self.running = False

    async def get_embeddings(self, text: str):
        """Return the embedding for *text*, computed on the worker thread.

        Raises:
            Exception: whatever the model raised while embedding this text.
        """
        with self.lock:
            task_id = random.randint(0, 1000000000)
            while task_id in self.task_map:
                task_id = random.randint(0, 1000000000)

            task_info = {
                "task_id": task_id,
                "text": text,
                "embedding": None
            }
            self.task_map[task_id] = task_info
            self.task_list.append(task_info)
            # Start (or keep) the worker while still holding the lock so the
            # worker cannot decide to exit between our append and this check.
            self._start_queue_locked()

        try:
            while True:
                task_info = self.pop_task(task_id)
                if task_info is not None:
                    error = task_info.get("error")
                    if error is not None:
                        raise error
                    return task_info["embedding"]
                await asyncio.sleep(0.01)
        finally:
            # If the waiter is cancelled, do not leak its map entry.
            with self.lock:
                self.task_map.pop(task_id, None)

    def pop_task(self, task_id):
        """Remove and return the task if it has finished, else ``None``."""
        with self.lock:
            task_info = self.task_map.get(task_id)
            if task_info is None:
                return None
            if task_info["embedding"] is None and task_info.get("error") is None:
                return None
            del self.task_map[task_id]
            return task_info

    def run(self):
        """Worker-thread body: process queued tasks until idle for too long."""
        last_task_time = None
        while True:
            task = None
            with self.lock:
                if not self.running:
                    return
                if self.task_list:
                    task = self.task_list.pop(0)
                elif (last_task_time is not None
                        and time.time() > last_task_time + BERT_EMBEDDING_QUEUE_TIMEOUT):
                    # Decide to stop while holding the lock and only when the
                    # queue is empty, so no enqueued task can be stranded.
                    self.thread = None
                    self.running = False
                    return

            if task is None:
                time.sleep(0.01)
                continue

            try:
                embeddings = self.embedding_model.encode([task["text"]])
                result = embeddings[0]
            except Exception as e:
                # Hand the failure to the waiter instead of killing the
                # worker and leaving every waiter polling forever.
                with self.lock:
                    task["error"] = e
            else:
                with self.lock:
                    task["embedding"] = result
            last_task_time = time.time()

    def _start_queue_locked(self):
        """Start the worker thread if needed. Caller must hold ``self.lock``."""
        if not self.running:
            self.running = True
            self.thread = threading.Thread(target=self.run)
            self.thread.start()

    def start_queue(self):
        """Start the worker thread if it is not already running."""
        with self.lock:
            self._start_queue_locked()
class TextEmbeddingService:
    """Singleton facade over the configured embedding backend.

    ``embedding.type`` selects the backend: ``"text2vec"`` (local sentence
    model behind a Text2VecEmbeddingQueue) or ``"openai"`` (OpenAIApi).
    """

    instance = None

    @staticmethod
    async def create() -> TextEmbeddingService:
        """Return the shared instance, initialising it on first use."""
        if TextEmbeddingService.instance is None:
            # Publish the instance only after init() has finished, so a
            # concurrent caller never receives a half-initialised service.
            service = TextEmbeddingService()
            await service.init()
            if TextEmbeddingService.instance is None:
                TextEmbeddingService.instance = service
        return TextEmbeddingService.instance

    async def init(self):
        """Read the configuration and construct the selected backend.

        Raises:
            ValueError: if ``embedding.type`` is neither ``"text2vec"`` nor
                ``"openai"``.
        """
        self.tiktoken = await TikTokenService.create()
        self.embedding_type = Config.get("embedding.type", "text2vec")

        if self.embedding_type == "text2vec":
            embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese")
            # The queue constructor instantiates SentenceModel (presumably a
            # slow model load), so build it in the executor. The original
            # instead called a non-existent `embedding_queue.init`.
            self.embedding_queue = await loop.run_in_executor(
                None, Text2VecEmbeddingQueue, embedding_model
            )
        elif self.embedding_type == "openai":
            # OpenAIApi.create is synchronous and requires an api_id.
            # NOTE(review): reusing `chatcomplete.default_api` here; confirm
            # whether embeddings should have their own api_id setting.
            api_id = Config.get("chatcomplete.default_api", None, str)
            self.openai_api: OpenAIApi = OpenAIApi.create(api_id)
        else:
            raise ValueError(f"Unsupported embedding.type: {self.embedding_type!r}")

    async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Embed each doc's ``"text"`` in place using the text2vec queue.

        Returns:
            ``(doc_list, token_usage)`` — the same tuple shape that callers
            unpack from ``get_embeddings``. The original returned ``None``.
        """
        total = len(doc_list)
        token_usage = 0
        for done, doc in enumerate(doc_list, start=1):
            text = doc["text"]
            token_usage += await self.tiktoken.get_tokens(text)
            doc["embedding"] = await self.embedding_queue.get_embeddings(text)

            if on_index_progress is not None:
                # Report the number of finished docs (was off by one).
                await on_index_progress(done, total)

        return (doc_list, token_usage)

    async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None):
        """Normalise each doc's text and embed it with the active backend.

        Multi-line text is flattened to one line: each non-empty line that
        does not already end in punctuation gets a trailing ".", and lines
        are joined with spaces. The input list is not modified.

        Returns:
            ``(docs, token_usage)`` where *docs* are copies of the input
            dicts with an ``"embedding"`` entry added.
        """
        res_doc_list = deepcopy(doc_list)

        # Used as a plain set of characters (membership test), not as a
        # compiled regular expression.
        punctuation = r"[=,.?!@#$%^&*()_+:\"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]"
        for doc in res_doc_list:
            text: str = doc["text"]
            text = text.replace("\r\n", "\n").replace("\r", "\n")
            if "\n" in text:
                new_lines = []
                for line in text.split("\n"):
                    line = line.strip()
                    # Add a dot at the end of the line if it doesn't end with a punctuation mark
                    if line and line[-1] not in punctuation:
                        line += "."
                    new_lines.append(line)
                text = " ".join(new_lines)
            doc["text"] = text

        if self.embedding_type == "text2vec":
            return await self.get_text2vec_embeddings(res_doc_list, on_index_progress)
        elif self.embedding_type == "openai":
            return await self.openai_api.get_embeddings(res_doc_list, on_index_progress)
        raise ValueError(f"Unsupported embedding.type: {self.embedding_type!r}")

@ -2,9 +2,9 @@ import sys
import pathlib
root_path = pathlib.Path(__file__).parent.parent
sys.path.append(root_path)
sys.path.append(".")
from config import Config
from lib.config import Config
Config.load_config(root_path + "/config.toml")
Config.load_config(str(root_path) + "/config.toml")
Config.set("server.debug", True)

@ -1,6 +1,6 @@
import traceback
import base
import local
import utils.local as local
from service.chat_complete import ChatCompleteService
from service.database import DatabaseService

@ -3,7 +3,7 @@ import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
import utils.local as local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService

@ -1,5 +1,5 @@
import base
import local
import utils.local as local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService

@ -1,13 +1,21 @@
import asyncio
import time
import base
from local import loop, noawait
from service.bert_embedding import bert_embedding_queue
import base as _
from utils.local import loop, noawait
from service.text_embedding import Text2VecEmbeddingQueue
async def main():
embedding_list = []
start_time = time.time()
queue = []
text2vec_queue = Text2VecEmbeddingQueue("shibing624/text2vec-base-chinese")
start_time = time.time()
async def on_progress(current, total):
print(f"{current}/{total}")
total_lines = 0
with open("test/test.md", "r", encoding="utf-8") as fp:
text = fp.read()
lines = text.split("\n")
@ -16,10 +24,13 @@ async def main():
if line == "":
continue
queue.append(bert_embedding_queue.get_embeddings(line))
queue.append(text2vec_queue.get_embeddings(line))
total_lines += 0
embedding_list = await asyncio.gather(*queue)
end_time = time.time()
print("total lines: %d" % total_lines)
print("time cost: %.4f" % (end_time - start_time))
print("speed: %.4f it/s" % (total_lines / (end_time - start_time)))
print("dimensions: %d" % len(embedding_list[0]))
await noawait.end()

@ -1,6 +1,6 @@
import asyncio
import base
from local import loop, noawait
from utils.local import loop, noawait
async def test_timer1():
print("timer1")

@ -3,7 +3,7 @@ import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
import utils.local as local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService

@ -1,5 +1,5 @@
import time
from config import Config
from lib.config import Config
def get_prompt(name: str, type: str, params: dict = {}):
sys_params = {

@ -0,0 +1,6 @@
import asyncio
from lib.noawait import NoAwaitPool
# Module-level debug flag, off by default. Where it is read is not visible here.
debug = False
# Dedicated event loop created at import time (new_event_loop does not make it
# the current loop; callers import `loop` from this module instead).
loop = asyncio.new_event_loop()
# Shared fire-and-forget task pool bound to `loop`.
noawait = NoAwaitPool(loop)

@ -5,7 +5,7 @@ from typing import Any, Optional, Dict
from aiohttp import web
import jwt
import uuid
from config import Config
from lib.config import Config
ParamRule = Dict[str, Any]

Loading…
Cancel
Save