From 071ab94829d69daa57c98936d8176e1cd3d3f273 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Wed, 19 Jun 2024 06:27:32 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3gitignore=E6=8E=92=E9=99=A4?= =?UTF-8?q?=E7=9B=AE=E5=BD=95=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- init/server/chat_complete.py | 5 +- init/service/mw_api.py | 2 +- init/service/tiktoken.py | 2 +- libs/config.py | 95 ++++++++++ libs/noawait.py | 165 ++++++++++++++++++ main.py | 12 +- maintenance/base.py | 7 +- server/controller/ChatComplete.py | 71 +++++--- server/controller/EmbeddingSearch.py | 19 +- server/controller/task/ChatCompleteTask.py | 29 ++- server/model/chat_complete/bot_persona.py | 38 ++-- .../chat_complete/bot_persona_category.py | 43 ----- server/model/embedding_search/page_index.py | 2 +- server/model/embedding_search/title_index.py | 2 +- server/route/chat_complete.py | 2 +- service/chat_complete.py | 55 ++++-- service/database.py | 2 +- service/embedding_search.py | 11 +- service/mediawiki_api.py | 78 ++++----- service/openai_api.py | 26 ++- service/text_embedding.py | 31 +++- test/base.py | 2 +- test/chatcomplete.py | 2 +- test/embedding_search.py | 4 +- utils/config.py | 2 +- utils/local.py | 2 +- utils/web.py | 4 +- 27 files changed, 502 insertions(+), 211 deletions(-) create mode 100755 libs/config.py create mode 100644 libs/noawait.py delete mode 100644 server/model/chat_complete/bot_persona_category.py diff --git a/init/server/chat_complete.py b/init/server/chat_complete.py index 85c98bf..764bdf9 100644 --- a/init/server/chat_complete.py +++ b/init/server/chat_complete.py @@ -5,12 +5,11 @@ import init.service.mw_api as _ # Init mediawiki api import init.service.database as _ # Init database import init.service.tiktoken as _ # Init tiktoken -import toolbox_ui as _ # Init toolbox ui -import embedding_search as _ # Init embedding search +import init.server.toolbox_ui as _ # Init toolbox ui +import init.server.embedding_search as _ # Init embedding search # Auto create database tables from server.model.chat_complete.conversation import ConversationChunkModel as _ -from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _ from server.model.chat_complete.bot_persona import BotPersonaModel as _ # Route diff --git a/init/service/mw_api.py b/init/service/mw_api.py index b2c28af..f7df49f 100644 --- a/init/service/mw_api.py +++ b/init/service/mw_api.py @@ -2,7 +2,7 @@ from __future__ import annotations from aiohttp import web from utils.server import register_server_module -from lib.config import Config +from libs.config import Config from service.mediawiki_api import MediaWikiApi diff --git a/init/service/tiktoken.py b/init/service/tiktoken.py index 780712c..fb70b99 100644 --- a/init/service/tiktoken.py +++ b/init/service/tiktoken.py @@ -9,7 +9,7 @@ async def init_tiktoken(app: web.Application): print("Tiktoken model loaded.") -async def init(app: web.Application): +def init(app: web.Application): app.on_startup.append(init_tiktoken) diff --git a/libs/config.py b/libs/config.py new file mode 100755 index 0000000..7d8323e --- /dev/null +++ b/libs/config.py @@ -0,0 +1,95 @@ +from __future__ import annotations +import os +import os.path +import toml +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer + +config_observer = Observer() + +class Config: + file_name: str + values: dict = {} + + @staticmethod + def load_config(file): + global config_observer + + Config.file_name = file + Config.reload() + + if config_observer.is_alive(): + config_observer.stop() + config_observer.join() + + config_observer.schedule(Config.ConfigEventHandler(), file) + config_observer.start() + + + @staticmethod + def reload(): + with open(Config.file_name, "r", encoding="utf-8") as f: + Config.values = toml.load(f) + + + @staticmethod + def get(key: str, default=None, type=None, empty_is_none=False): + key_path = key.split(".") + value = Config.values + for k in key_path: + if k in value: + value = value[k] + else: + return default + + if empty_is_none and value == "": + return None + + if type == bool: + if isinstance(value, bool): + return value + elif isinstance(value, int) or isinstance(value, float): + return value != 0 + else: + return str(value).lower() in ("yes", "true", "1") + elif type == int: + return int(value) + elif type == float: + return float(value) + elif type == str: + return str(value) + elif type == list: + if not isinstance(value, list): + return [] + elif type == dict: + if not isinstance(value, dict): + return {} + else: + return value + + + @staticmethod + def set(key: str, value): + key_path = key.split(".") + obj = Config.values + for k in key_path[:-1]: + if k not in obj: + obj[k] = {} + obj = obj[k] + obj[key_path[-1]] = value + + + @staticmethod + def unload(): + global config_observer + + if config_observer.is_alive(): + config_observer.stop() + config_observer.join() + + + class ConfigEventHandler(FileSystemEventHandler): + def on_modified(self, event): + if os.path.basename(event.src_path) == os.path.basename(Config.file_name): + print("Config file changed, reloading...") + Config.reload() \ No newline at end of file diff --git a/libs/noawait.py b/libs/noawait.py new file mode 100644 index 0000000..e0e6c20 --- /dev/null +++ b/libs/noawait.py @@ -0,0 +1,165 @@ +from __future__ import annotations +from asyncio import AbstractEventLoop, Task +import asyncio +import atexit +from functools import wraps +import random + +import sys +import traceback +from typing import Callable, Coroutine, Optional, TypedDict + +class TimerInfo(TypedDict): + id: int + callback: Callable + interval: float + next_time: float + +class NoAwaitPool: + def __init__(self, loop: AbstractEventLoop): + self.task_list: list[Task] = [] + self.timer_map: dict[int, TimerInfo] = {} + self.loop = loop + self.running = True + + self.should_refresh_task = False + self.next_timer_time: Optional[float] = None + + self.on_error: list[Callable] = [] + + self.gc_task = loop.create_task(self._run_gc()) + self.timer_task = loop.create_task(self._run_timer()) + + atexit.register(self.end_task) + + async def end(self): + if self.running: + print("Stopping NoAwait Tasks...") + self.running = False + for task in self.task_list: + await self._finish_task(task) + + await self.gc_task + await self.timer_task + + def end_task(self): + if self.running and not self.loop.is_closed(): + self.loop.run_until_complete(self.end()) + + async def _wrap_task(self, task: Task): + try: + await task + except Exception as e: + handled = False + for handler in self.on_error: + try: + handler_ret = handler(e) + await handler_ret + handled = True + except Exception as handler_err: + print("Exception on error handler: " + str(handler_err), file=sys.stderr) + traceback.print_exc() + + if not handled: + print(e, file=sys.stderr) + traceback.print_exc() + finally: + self.should_refresh_task = True + + def add_task(self, coroutine: Coroutine): + task = self.loop.create_task(coroutine) + self.task_list.append(task) + + def add_timer(self, callback: Callable, interval: float) -> int: + id = random.randint(0, 1000000000) + while id in self.timer_map: + id = random.randint(0, 1000000000) + + now = self.loop.time() + next_time = now + interval + self.timer_map[id] = { + "id": id, + "callback": callback, + "interval": interval, + "next_time": next_time + } + + if self.next_timer_time is None or next_time < self.next_timer_time: + self.next_timer_time = next_time + + return id + + def remove_timer(self, id: int): + if id in self.timer_map: + del self.timer_map[id] + + def wrap(self, f): + @wraps(f) + def decorated_function(*args, **kwargs): + coroutine = f(*args, **kwargs) + self.add_task(coroutine) + + return decorated_function + + async def _finish_task(self, task: Task): + try: + if not task.done(): + task.cancel() + await task + except Exception as e: + handled = False + for handler in self.on_error: + try: + handler_ret = handler(e) + await handler_ret + handled = True + except Exception as handler_err: + print("Exception on error handler: " + str(handler_err), file=sys.stderr) + traceback.print_exc() + + if not handled: + print(e, file=sys.stderr) + traceback.print_exc() + + async def _run_gc(self): + while self.running: + if self.should_refresh_task: + should_remove = [] + for task in self.task_list: + if task.done(): + await self._finish_task(task) + should_remove.append(task) + for task in should_remove: + self.task_list.remove(task) + + await asyncio.sleep(0.1) + + async def _run_timer(self): + while self.running: + now = self.loop.time() + if self.next_timer_time is not None and now >= self.next_timer_time: + self.next_timer_time = None + for timer in self.timer_map.values(): + if now >= timer["next_time"]: + timer["next_time"] = now + timer["interval"] + try: + result = timer["callback"]() + self.add_task(result) + except Exception as e: + handled = False + for handler in self.on_error: + try: + handler_ret = handler(e) + self.add_task(handler_ret) + handled = True + except Exception as handler_err: + print("Exception on error handler: " + str(handler_err), file=sys.stderr) + traceback.print_exc() + + if not handled: + print(e, file=sys.stderr) + traceback.print_exc() + if self.next_timer_time is None or timer["next_time"] < self.next_timer_time: + self.next_timer_time = timer["next_time"] + + await asyncio.sleep(0.1) \ No newline at end of file diff --git a/main.py b/main.py index 7a078e0..9d9e721 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import sys import traceback -from lib.config import Config +from libs.config import Config from utils.server import get_server_modules Config.load_config("config.toml") @@ -8,9 +8,7 @@ Config.load_config("config.toml") from utils.local import loop, noawait from aiohttp import web import utils.local as local -import server.route as api_route import utils.web -from service.mediawiki_api import MediaWikiApi async def index(request: web.Request): @@ -53,8 +51,9 @@ async def error_handler(request, handler): ) -async def stop_noawait_pool(app: web.Application): +async def on_shutdown(app: web.Application): await noawait.end() + Config.unload() if __name__ == "__main__": local.debug = Config.get("server.debug", False, bool) @@ -79,15 +78,14 @@ if __name__ == "__main__": # Initialize server modules server_modules = get_server_modules() - print(server_modules) for server_module in server_modules: + print("Loading server module: " + server_module["name"]) server_module["init"](app) - app.on_shutdown.append(stop_noawait_pool) + app.on_shutdown.append(on_shutdown) app.router.add_route("*", "/", index) - api_route.init(app) server_port = Config.get("port", 8144, int) diff --git a/maintenance/base.py b/maintenance/base.py index bb81f7d..c661335 100644 --- a/maintenance/base.py +++ b/maintenance/base.py @@ -1,4 +1,9 @@ import sys import pathlib -sys.path.append(str(pathlib.Path(__file__).parent.parent)) \ No newline at end of file +root_path = pathlib.Path(__file__).parent.parent +sys.path.append(".") + +from libs.config import Config + +Config.load_config(str(root_path) + "/config.toml") \ No newline at end of file diff --git a/server/controller/ChatComplete.py b/server/controller/ChatComplete.py index 2a21f4a..87c0647 100644 --- a/server/controller/ChatComplete.py +++ b/server/controller/ChatComplete.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import time import traceback +from libs.config import Config from server.controller.task.ChatCompleteTask import ChatCompleteTask from server.model.base import clone_model from server.model.chat_complete.bot_persona import BotPersonaHelper @@ -9,9 +10,9 @@ from server.model.toolbox_ui.conversation import ConversationHelper from utils.local import noawait from aiohttp import web from server.model.chat_complete.conversation import ConversationChunkHelper, ConversationModel, ConversationChunkModel -from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse +from service.chat_complete import ChatCompleteQuestionTooLongException, ChatCompleteServiceResponse, calculate_point_usage from service.database import DatabaseService -from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException +from service.mediawiki_api import MediaWikiPageNotFoundException, MediaWikiUserNoEnoughPointsException from service.tiktoken import TikTokenService import utils.web @@ -63,6 +64,7 @@ class ChatComplete: return await utils.web.api_response(1, conversation_chunk_list, request=request) + @staticmethod @utils.web.token_auth async def get_conversation_chunk(request: web.Request): @@ -111,6 +113,7 @@ class ChatComplete: return await utils.web.api_response(1, chunk_dict, request=request) + @staticmethod @utils.web.token_auth async def fork_conversation(request: web.Request): @@ -223,6 +226,7 @@ class ChatComplete: "conversation_id": new_conversation.id, }, request=request) + @staticmethod @utils.web.token_auth async def get_tokens(request: web.Request): @@ -240,14 +244,19 @@ class ChatComplete: return await utils.web.api_response(1, {"tokens": tokens}, request=request) + @staticmethod @utils.web.token_auth - async def get_point_cost(request: web.Request): + async def get_point_usage(request: web.Request): params = await utils.web.get_param(request, { "question": { "type": str, "required": True, }, + "bot_id": { + "type": str, + "required": True, + }, "extract_limit": { "type": int, "required": False, @@ -256,39 +265,41 @@ class ChatComplete: }) user_id = request.get("user") - caller = request.get("caller") question = params.get("question") - extract_limit = params.get("extract_limit") + bot_id = params.get("bot_id") tiktoken = await TikTokenService.create() - mwapi = MediaWikiApi.create() + db = await DatabaseService.create(request.app) tokens = await tiktoken.get_tokens(question) - try: - res = await mwapi.ai_toolbox_get_point_cost(user_id, "chatcomplete", tokens, extract_limit) + estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int) + + predict_tokens = tokens + estimated_extract_tokens_per_doc + + async with BotPersonaHelper(db) as bot_persona_helper: + persona_info = await bot_persona_helper.find_by_bot_id(bot_id) + + if persona_info is None: + return await utils.web.api_response(-1, error={ + "code": "bot-not-found", + "message": "Bot not found." + }, http_status=404, request=request) + + point_usage = calculate_point_usage(predict_tokens, persona_info.cost_fixed, persona_info.cost_fixed_tokens, persona_info.cost_per_token) + return await utils.web.api_response(1, { - "point_cost": res["point_cost"], + "point_usage": point_usage, "tokens": tokens, + "predict_tokens": predict_tokens, }, request=request) - except Exception as e: - err_msg = f"Error while get chat complete point cost: {e}" - traceback.print_exc() - - return await utils.web.api_response(-1, error={ - "code": "internal-server-error", - "message": err_msg - }, http_status=500, request=request) - + + @staticmethod @utils.web.token_auth async def get_persona_list(request: web.Request): params = await utils.web.get_param(request, { - "category_id": { - "type": int, - "required": False, - }, "page": { "type": int, "required": False, @@ -296,13 +307,12 @@ class ChatComplete: } }) - category_id = params.get("category_id") page = params.get("page") db = await DatabaseService.create(request.app) async with BotPersonaHelper(db) as bot_persona_helper: - persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id) - page_count = await bot_persona_helper.get_page_count(category_id=category_id) + persona_list = await bot_persona_helper.get_list(page=page) + page_count = await bot_persona_helper.get_page_count() persona_data_list = [] for persona in persona_list: @@ -312,13 +322,16 @@ class ChatComplete: "bot_name": persona.bot_name, "bot_avatar": persona.bot_avatar, "bot_description": persona.bot_description, - "updated_at": persona.updated_at, + "model_id": persona.model_id, + "model_name": persona.model_name, + "cost_fixed": persona.cost_fixed }) return await utils.web.api_response(1, { "list": persona_data_list, "page_count": page_count, }, request=request) + @staticmethod @utils.web.token_auth @@ -353,6 +366,7 @@ class ChatComplete: persona_info_res[key] = value return await utils.web.api_response(1, persona_info_res, request=request) + @staticmethod @utils.web.token_auth @@ -372,7 +386,7 @@ class ChatComplete: }, "bot_id": { "type": str, - "required": False, + "required": True, }, "extract_limit": { "type": int, @@ -453,6 +467,7 @@ class ChatComplete: "message": err_msg }, http_status=500, request=request) + @staticmethod @utils.web.token_auth async def chat_complete_stream(request: web.Request): @@ -541,7 +556,7 @@ class ChatComplete: try: ignored_keys = ["message"] response_result = { - "point_cost": task.point_cost, + "point_usage": task.point_usage, } for k, v in result.items(): if k not in ignored_keys: diff --git a/server/controller/EmbeddingSearch.py b/server/controller/EmbeddingSearch.py index 6b3f976..a47a568 100644 --- a/server/controller/EmbeddingSearch.py +++ b/server/controller/EmbeddingSearch.py @@ -1,6 +1,7 @@ import sys import traceback from aiohttp import web +from libs.config import Config from server.model.embedding_search.title_collection import TitleCollectionHelper from server.model.embedding_search.title_index import TitleIndexHelper from service.database import DatabaseService @@ -53,10 +54,10 @@ class EmbeddingSearch: page_current += 1 async with EmbeddingSearchService(db, one_title) as embedding_search: if await embedding_search.should_update_page_index(): - if request.get("caller") == "user": - user_id = request.get("user") - usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") - transatcion_id = usage_res.get("transaction_id") + # if request.get("caller") == "user": + # user_id = request.get("user") + # usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") + # transatcion_id = usage_res.get("transaction_id") await embedding_search.prepare_update_index() @@ -144,10 +145,10 @@ class EmbeddingSearch: async with EmbeddingSearchService(db, page_title) as embedding_search: if await embedding_search.should_update_page_index(): - if request.get("caller") == "user": - user_id = request.get("user") - usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") - transatcion_id = usage_res.get("transaction_id") + # if request.get("caller") == "user": + # user_id = request.get("user") + # usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") + # transatcion_id = usage_res.get("transaction_id") await embedding_search.prepare_update_index() @@ -236,7 +237,7 @@ class EmbeddingSearch: "distancelimit": { "required": False, "type": float, - "default": 0.6 + "default": None }, }) diff --git a/server/controller/task/ChatCompleteTask.py b/server/controller/task/ChatCompleteTask.py index c38c0c6..a3a953f 100644 --- a/server/controller/task/ChatCompleteTask.py +++ b/server/controller/task/ChatCompleteTask.py @@ -2,12 +2,15 @@ from __future__ import annotations import sys import time import traceback +from libs.config import Config +from server.model.chat_complete.bot_persona import BotPersonaHelper from utils.local import noawait from typing import Optional, Callable, Union from service.chat_complete import ( ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse, + calculate_point_usage, ) from service.database import DatabaseService from service.embedding_search import EmbeddingSearchArgs @@ -41,7 +44,7 @@ class ChatCompleteTask: self.is_system = is_system self.transatcion_id: Optional[str] = None - self.point_cost: int = 0 + self.point_usage: int = 0 self.is_finished = False self.finished_time: Optional[float] = None @@ -51,9 +54,9 @@ class ChatCompleteTask: async def init( self, question: str, + bot_id: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None, - bot_id: Optional[str] = None, embedding_search: Optional[EmbeddingSearchArgs] = None, ) -> ChatCompleteServicePrepareResponse: self.tiktoken = await TikTokenService.create() @@ -69,19 +72,31 @@ class ChatCompleteTask: extract_limit = embedding_search["limit"] or 10 + estimated_extract_tokens_per_doc = Config.get("estimated_extract_tokens_per_doc", 50, int) + + predict_tokens = extract_limit + estimated_extract_tokens_per_doc + + async with BotPersonaHelper(self.dbs) as bot_persona_helper: + bot_persona = await bot_persona_helper.find_by_bot_id(bot_id) + + self.point_usage: int = calculate_point_usage(predict_tokens, bot_persona.cost_fixed, + bot_persona.cost_fixed_tokens, bot_persona.cost_per_token) + self.transatcion_id: Optional[str] = None - self.point_cost: int = 0 if not self.is_system: usage_res = await self.mwapi.ai_toolbox_start_transaction( - self.user_id, "chatcomplete", question_tokens, extract_limit + self.user_id, "chatcomplete", + bot_id=bot_id, + tokens=predict_tokens, + point_usage=self.point_usage ) self.transatcion_id = usage_res["transaction_id"] - self.point_cost = usage_res["point_cost"] res = await self.chat_complete.prepare_chat_complete( question, conversation_id=conversation_id, user_id=self.user_id, + question_tokens=question_tokens, edit_message_id=edit_message_id, bot_id=bot_id, embedding_search=embedding_search, @@ -136,13 +151,13 @@ class ChatCompleteTask: try: chat_res = await self.chat_complete.finish_chat_complete(self._on_message) - await self.chat_complete.set_latest_point_cost(self.point_cost) + await self.chat_complete.set_latest_point_usage(self.point_usage) self.result = chat_res if self.transatcion_id: await self.mwapi.ai_toolbox_end_transaction( - self.transatcion_id, chat_res["total_tokens"] + self.transatcion_id, self.point_usage # TODO: 根据实际使用Tokens扣除积分 ) await self._on_finished() diff --git a/server/model/chat_complete/bot_persona.py b/server/model/chat_complete/bot_persona.py index 24b1d24..8b01ecf 100644 --- a/server/model/chat_complete/bot_persona.py +++ b/server/model/chat_complete/bot_persona.py @@ -5,10 +5,9 @@ import sqlalchemy from server.model.base import BaseHelper, BaseModel import sqlalchemy -from sqlalchemy import select, update -from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped +from sqlalchemy import select +from sqlalchemy.orm import mapped_column, load_only, Mapped -from server.model.chat_complete.bot_persona_category import BotPersonaCategoryModel from service.database import DatabaseService @@ -22,17 +21,16 @@ class BotPersonaModel(BaseModel): bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) - category_id: Mapped[int] = mapped_column( - sqlalchemy.ForeignKey( - BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE" - ), - index=True, - ) api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + model_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True, index=True) + model_name: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) - message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) + message_log: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) - updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) + cost_fixed: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True) + cost_fixed_tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) + cost_per_token: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True) + order: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True) class BotPersonaHelper(BaseHelper): @@ -50,8 +48,7 @@ class BotPersonaHelper(BaseHelper): async def get_list( self, page: Optional[int] = 1, - page_size: Optional[int] = 20, - category_id: Optional[int] = None, + page_size: Optional[int] = 20 ): offset_index = (page - 1) * page_size @@ -64,25 +61,22 @@ class BotPersonaHelper(BaseHelper): BotPersonaModel.bot_name, BotPersonaModel.bot_avatar, BotPersonaModel.bot_description, - BotPersonaModel.updated_at, + BotPersonaModel.model_id, + BotPersonaModel.model_name, + BotPersonaModel.cost_fixed, + BotPersonaModel.order, ) ) - .order_by(BotPersonaModel.updated_at.desc()) + .order_by(BotPersonaModel.order.desc()) .offset(offset_index) .limit(page_size) ) - if category_id is not None: - stmt = stmt.where(BotPersonaModel.category_id == category_id) - return await self.session.scalars(stmt) - async def get_page_count(self, page_size=50, category_id: Optional[int] = None): + async def get_page_count(self, page_size=50): stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel) - if category_id is not None: - stmt = stmt.where(BotPersonaModel.category_id == category_id) - item_count = await self.session.scalar(stmt) if item_count is None: item_count = 0 diff --git a/server/model/chat_complete/bot_persona_category.py b/server/model/chat_complete/bot_persona_category.py deleted file mode 100644 index a541dea..0000000 --- a/server/model/chat_complete/bot_persona_category.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations -import sqlalchemy -from server.model.base import BaseHelper, BaseModel - -import sqlalchemy -from sqlalchemy import select, update -from sqlalchemy.orm import mapped_column, relationship, Mapped - -class BotPersonaCategoryModel(BaseModel): - __tablename__ = "chat_complete_bot_persona_category" - - id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) - name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) - description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) - icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) - font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) - color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True) - order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) - -class BotPersonaHelper(BaseHelper): - async def add(self, obj: BotPersonaCategoryModel): - self.session.add(obj) - await self.session.commit() - await self.session.refresh(obj) - return obj - - async def update(self, obj: BotPersonaCategoryModel): - obj = await self.session.merge(obj) - await self.session.commit() - return obj - - async def remove(self, id: int): - stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id) - await self.session.execute(stmt) - await self.session.commit() - - async def get_list(self): - stmt = select(BotPersonaCategoryModel) - return await self.session.scalars(stmt) - - async def find_by_id(self, id: int): - stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id) - return await self.session.scalar(stmt) \ No newline at end of file diff --git a/server/model/embedding_search/page_index.py b/server/model/embedding_search/page_index.py index a9c9e29..5862bdd 100644 --- a/server/model/embedding_search/page_index.py +++ b/server/model/embedding_search/page_index.py @@ -6,7 +6,7 @@ import asyncpg from server.model.base import BaseModel import numpy as np import sqlalchemy -from lib.config import Config +from libs.config import Config from sqlalchemy import Index, select, update, delete, Select from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.ext.asyncio import AsyncSession diff --git a/server/model/embedding_search/title_index.py b/server/model/embedding_search/title_index.py index 405db21..57c1a84 100644 --- a/server/model/embedding_search/title_index.py +++ b/server/model/embedding_search/title_index.py @@ -8,7 +8,7 @@ import sqlalchemy from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.ext.asyncio import AsyncEngine -from lib.config import Config +from libs.config import Config from server.model.base import BaseHelper, BaseModel from service.database import DatabaseService diff --git a/server/route/chat_complete.py b/server/route/chat_complete.py index 0c74881..9a8a45f 100644 --- a/server/route/chat_complete.py +++ b/server/route/chat_complete.py @@ -9,7 +9,7 @@ def register_route(app: web.Application): web.route('POST', '/chatcomplete/conversation/fork', ChatComplete.fork_conversation), web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), - web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), + web.route('POST', '/chatcomplete/get_point_usage', ChatComplete.get_point_usage), web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list), web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info), ]) \ No newline at end of file diff --git a/service/chat_complete.py b/service/chat_complete.py index 1c573b7..12b47b2 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -1,4 +1,5 @@ from __future__ import annotations +import math import time import traceback from typing import Optional, Tuple, TypedDict @@ -11,7 +12,7 @@ from server.model.chat_complete.conversation import ( import sys from server.model.toolbox_ui.conversation import ConversationHelper, ConversationModel -from lib.config import Config +from libs.config import Config import utils.config, utils.web from aiohttp import web @@ -72,6 +73,7 @@ class ChatCompleteService: self.question = "" self.question_tokens: Optional[int] = None self.bot_id: str = "" + self.model: Optional[str] = None self.conversation_id: Optional[int] = None self.conversation_start_time: Optional[int] = None @@ -104,11 +106,11 @@ class ChatCompleteService: async def prepare_chat_complete( self, question: str, + bot_id: str, conversation_id: Optional[str] = None, user_id: Optional[int] = None, question_tokens: Optional[int] = None, edit_message_id: Optional[str] = None, - bot_id: Optional[str] = None, embedding_search: Optional[EmbeddingSearchArgs] = None, ) -> ChatCompleteServicePrepareResponse: if user_id is not None: @@ -117,7 +119,7 @@ class ChatCompleteService: self.user_id = user_id self.question = question self.conversation_start_time = int(time.time()) - self.bot_id = bot_id or None + self.bot_id = bot_id self.conversation_info = None if conversation_id is not None: @@ -153,14 +155,17 @@ class ChatCompleteService: if bot_persona is None: self.bot_id = "default" + self.model = None bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id) else: + self.model = bot_persona.model_id self.chat_system_prompt = bot_persona.system_prompt default_api = Config.get("chatcomplete.default_api", None, str) try: self.openai_api = OpenAIApi.create(bot_persona.api_id or default_api) except OpenAIApiTypeInvalidException: + print(f"Invalid API type: {bot_persona.api_id}", file=sys.stderr) self.openai_api = OpenAIApi.create(default_api) self.conversation_chunk = None @@ -236,10 +241,11 @@ class ChatCompleteService: init_message_data = [] if bot_persona is not None: current_time = int(time.time()) - for message in bot_persona.message_log: - message["id"] = utils.web.generate_uuid() - message["time"] = current_time - init_message_data.append(message) + if bot_persona.message_log is not None: + for message in bot_persona.message_log: + message["id"] = utils.web.generate_uuid() + message["time"] = current_time + init_message_data.append(message) title_info = self.embedding_search.title_index self.conversation_info = ConversationModel( @@ -319,11 +325,11 @@ class ChatCompleteService: # Start chat complete if on_message is not None: response = await self.openai_api.chat_complete_stream( - self.question, system_prompt, message_log, on_message + self.question, system_prompt, self.model, message_log, on_message ) else: response = await self.openai_api.chat_complete( - self.question, system_prompt, message_log + self.question, system_prompt, self.model, message_log ) description = response["message"][0:150] @@ -384,7 +390,7 @@ class ChatCompleteService: delta_data=delta_data, ) - async def set_latest_point_cost(self, point_cost: int) -> bool: + async def set_latest_point_usage(self, point_usage: int) -> bool: if self.conversation_chunk is None: return False @@ -393,7 +399,7 @@ class ChatCompleteService: for i in range(len(self.conversation_chunk.message_data) - 1, -1, -1): if self.conversation_chunk.message_data[i].get("role") == "assistant": - self.conversation_chunk.message_data[i]["point_cost"] = point_cost + self.conversation_chunk.message_data[i]["point_usage"] = point_usage flag_modified(self.conversation_chunk, "message_data") await self.conversation_chunk_helper.update(self.conversation_chunk) @@ -402,6 +408,10 @@ class ChatCompleteService: async def make_summary(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) + api_id = Config.get("chatcomplete.system_api_id", "default", str) + model_id = Config.get("chatcomplete.system_api_id", "default", str) + + openai_api = OpenAIApi.create(api_id) for message_data in message_log_list: if "content" in message_data: @@ -422,8 +432,8 @@ class ChatCompleteService: "make_summary", "prompt", {"content": chat_log_str} ) - response = await self.openai_api.chat_complete( - summary_prompt, summary_system_prompt + response = await openai_api.chat_complete( + summary_prompt, summary_system_prompt, model_id ) return response["message"], response["message_tokens"] @@ -431,6 +441,10 @@ class ChatCompleteService: async def make_title(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) + api_id = Config.get("chatcomplete.system_api_id", "default", str) + model_id = Config.get("chatcomplete.system_api_id", "default", str) + + openai_api = OpenAIApi.create(api_id) for message_data in message_log_list: if "content" in message_data: @@ -451,7 +465,18 @@ class ChatCompleteService: "make_title", "prompt", {"content": chat_log_str} ) - response = await self.openai_api.chat_complete( - title_prompt, title_system_prompt + response = await openai_api.chat_complete( + title_prompt, title_system_prompt, model_id ) + + if response["message"] is None: + print(response) + raise Exception("Title generation failed") + return response["message"][0:250], response["message_tokens"] + + +def calculate_point_usage(tokens: int, cost_fixed: float, cost_fixed_tokens: int, cost_per_token: float) -> int: + if tokens <= cost_fixed_tokens: + return cost_fixed + return cost_fixed + math.ceil((tokens - cost_fixed_tokens) * cost_per_token) \ No newline at end of file diff --git a/service/database.py b/service/database.py index 6d8f1e9..bc6c3a1 100644 --- a/service/database.py +++ b/service/database.py @@ -4,7 +4,7 @@ from urllib.parse import quote_plus from aiohttp import web import asyncpg from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from lib.config import Config +from libs.config import Config def get_dsn(): db_conf = Config.get("database") diff --git a/service/embedding_search.py b/service/embedding_search.py index aea6193..288c218 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Optional, TypedDict import sqlalchemy +from libs.config import Config from server.model.embedding_search.title_collection import ( TitleCollectionHelper, TitleCollectionModel, @@ -10,7 +11,6 @@ from server.model.embedding_search.title_index import TitleIndexHelper, TitleInd from server.model.embedding_search.page_index import PageIndexHelper from service.database import DatabaseService from service.mediawiki_api import MediaWikiApi -from service.openai_api import OpenAIApi from service.text_embedding import TextEmbeddingService from service.tiktoken import TikTokenService from utils.wiki import getWikiSentences @@ -43,7 +43,6 @@ class EmbeddingSearchService: self.tiktoken: TikTokenService = None self.mwapi = MediaWikiApi.create() - self.openai_api = OpenAIApi.create() self.page_id: Optional[int] = None self.collection_id: Optional[int] = None @@ -284,10 +283,13 @@ class EmbeddingSearchService: query: str, limit: int = 10, in_collection: bool = False, - distance_limit: float = 0.6, + distance_limit: Optional[float] = None, ): if limit == 0: return [], 0 + + if distance_limit is None: + distance_limit = Config.get("embedding.default_distance_limit") if self.page_index is None: raise Exception("Page index is not initialized") @@ -296,12 +298,15 @@ class EmbeddingSearchService: query_doc, token_usage = await self.text_embedding.get_embeddings(query_doc) query_embedding = query_doc[0]["embedding"] + print(query_embedding) + if query_embedding is None: return [], token_usage res = await self.page_index.search_text_embedding( query_embedding, in_collection, limit, self.page_id ) + print(res) if res: filtered = [] for one in res: diff --git a/service/mediawiki_api.py b/service/mediawiki_api.py index 24cccbe..3a135e6 100644 --- a/service/mediawiki_api.py +++ b/service/mediawiki_api.py @@ -4,7 +4,7 @@ import sys import time from typing import Optional, TypedDict import aiohttp -from lib.config import Config +from libs.config import Config class MediaWikiApiException(Exception): def __init__(self, info: str, code: Optional[str] = None) -> None: @@ -41,10 +41,9 @@ class GetAllPagesResponse(TypedDict): continue_key: Optional[str] class ChatCompleteGetPointUsageResponse(TypedDict): - point_cost: int + point_usage: int class ChatCompleteReportUsageResponse(TypedDict): - point_cost: int transaction_id: str class MediaWikiApi: @@ -58,6 +57,7 @@ class MediaWikiApi: return MediaWikiApi.instance + def __init__(self, api_url: str): self.api_url = api_url self.request_proxy = Config.get("request.proxy", type=str, empty_is_none=True) @@ -66,6 +66,7 @@ class MediaWikiApi: self.login_identity = None self.login_time = 0.0 + async def get_page_info(self, title: str): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -85,7 +86,8 @@ class MediaWikiApi: raise MediaWikiPageNotFoundException(data["error"]["info"], data["error"]["code"]) return data["query"]["pages"][0] - + + async def parse_page(self, title: str): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -104,7 +106,8 @@ class MediaWikiApi: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) return data["parse"]["text"] - + + async def get_site_meta(self): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -130,7 +133,8 @@ class MediaWikiApi: ret["user"] = data["query"]["userinfo"]["name"] return ret - + + async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -162,7 +166,8 @@ class MediaWikiApi: ret["continue_key"] = data["continue"]["apcontinue"] return ret - + + async def is_logged_in(self): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -177,7 +182,8 @@ class MediaWikiApi: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) return data["query"]["userinfo"]["id"] != 0 - + + async def get_token(self, token_type: str): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -193,7 +199,8 @@ class MediaWikiApi: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) return data["query"]["tokens"][token_type + "token"] - + + async def robot_login(self, username: str, password: str): if await self.is_logged_in(): self.login_time = time.time() @@ -224,7 +231,8 @@ class MediaWikiApi: } return True - + + async def refresh_login(self): if self.login_identity is None: print("刷新MW机器人账号登录状态失败:没有保存的用户") @@ -232,7 +240,8 @@ class MediaWikiApi: if time.time() - self.login_time > 3600: print("刷新MW机器人账号登录状态") return await self.robot_login(self.login_identity["username"], self.login_identity["password"]) - + + async def search_title(self, keyword: str) -> list[str]: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = { @@ -244,7 +253,8 @@ class MediaWikiApi: async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp: data = await resp.json() return data[1] - + + async def ai_toolbox_get_user_info(self, user_id: int): await self.refresh_login() @@ -266,32 +276,9 @@ class MediaWikiApi: return data["aitoolboxbot"]["userinfo"] - async def ai_toolbox_get_point_cost(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteGetPointUsageResponse: - await self.refresh_login() - async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: - post_data = { - "action": "aitoolboxbot", - "method": "reportusage", - "step": "check", - "userid": int(user_id) if user_id is not None else None, - "useraction": user_action, - "tokens": int(tokens) if tokens is not None else None, - "extractlines": int(extractlines) if extractlines is not None else None, - "format": "json", - "formatversion": "2", - } - # Filter out None values - post_data = {k: v for k, v in post_data.items() if v is not None} - async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp: - data = await resp.json() - if "error" in data: - raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) - - point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0) - return ChatCompleteGetPointUsageResponse(point_cost=point_cost) - - async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, tokens: Optional[int] = None, extractlines: Optional[int] = None) -> ChatCompleteReportUsageResponse: + async def ai_toolbox_start_transaction(self, user_id: int, user_action: str, bot_id: Optional[str] = None, + tokens: Optional[int] = None, point_usage: Optional[int] = None) -> ChatCompleteReportUsageResponse: await self.refresh_login() async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: @@ -301,8 +288,9 @@ class MediaWikiApi: "step": "start", "userid": int(user_id) if user_id is not None else None, "useraction": user_action, + "botid": bot_id, "tokens": int(tokens) if tokens is not None else None, - "extractlines": int(extractlines) if extractlines is not None else None, + "pointusage": int(point_usage) if point_usage is not None else 0, "format": "json", "formatversion": "2", } @@ -316,11 +304,10 @@ class MediaWikiApi: else: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) - point_cost = int(data["aitoolboxbot"]["reportusage"]["pointcost"] or 0) - return ChatCompleteReportUsageResponse(point_cost=point_cost, - transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"]) - - async def ai_toolbox_end_transaction(self, transaction_id: str, tokens: Optional[int] = None): + return ChatCompleteReportUsageResponse(transaction_id=data["aitoolboxbot"]["reportusage"]["transactionid"]) + + + async def ai_toolbox_end_transaction(self, transaction_id: str, point_usage: Optional[int] = None, tokens: Optional[int] = None): await self.refresh_login() try: @@ -330,7 +317,7 @@ class MediaWikiApi: "method": "reportusage", "step": "end", "transactionid": transaction_id, - "tokens": tokens, + "pointusage": int(point_usage) if point_usage is not None else 0, "format": "json", "formatversion": "2", } @@ -344,7 +331,8 @@ class MediaWikiApi: return data["aitoolboxbot"]["reportusage"]["success"] except Exception as e: print(e, file=sys.stderr) - + + async def ai_toolbox_cancel_transaction(self, transaction_id: str, error: Optional[str] = None): await self.refresh_login() diff --git a/service/openai_api.py b/service/openai_api.py index 97a7f62..1b0af29 100644 --- a/service/openai_api.py +++ b/service/openai_api.py @@ -4,7 +4,7 @@ import json from typing import Callable, Optional, TypedDict import aiohttp -from lib.config import Config +from libs.config import Config import numpy as np from aiohttp_sse_client2 import client as sse_client @@ -65,11 +65,18 @@ class OpenAIApi: def get_url(self, method: str): if self.api_type == "azure": - deployments = Config.get(f"chatcomplete.{self.api_id}.deployments") if method == "chat/completions": - return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method + deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_chatcomplete") + if deployment is None: + raise Exception("deployment for chatcomplete is not set") + + return self.api_url + "/openai/deployments/" + deployment + "/" + method elif method == "embeddings": - return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method + deployment = Config.get(f"chatcomplete.{self.api_id}.deployment_embedding") + if deployment is None: + raise Exception("deployment for embedding is not set") + + return self.api_url + "/openai/deployments/" + deployment + "/" + method else: return self.api_url + "/v1/" + method @@ -175,7 +182,7 @@ class OpenAIApi: return messageList - async def chat_complete(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], user = None): + async def chat_complete(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], user = None): messageList = await self.make_message_list(question, system_prompt, conversation) url = self.get_url("chat/completions") @@ -189,7 +196,7 @@ class OpenAIApi: if self.api_type == "azure": params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION else: - post_data["model"] = "gpt-3.5-turbo" + post_data["model"] = model post_data = {k: v for k, v in post_data.items() if v is not None} @@ -218,10 +225,13 @@ class OpenAIApi: message_tokens=message_tokens, total_tokens=total_tokens, finish_reason=finish_reason) + else: + print(data) + raise Exception("Invalid response from chat complete api") return None - async def chat_complete_stream(self, question: str, system_prompt: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None): + async def chat_complete_stream(self, question: str, system_prompt: str, model: str, conversation: list[ChatCompleteMessageLog] = [], on_message = None, user = None): tiktoken = await TikTokenService.create() messageList = await self.make_message_list(question, system_prompt, conversation) @@ -247,7 +257,7 @@ class OpenAIApi: if self.api_type == "azure": params["api-version"] = AZURE_CHATCOMPLETE_API_VERSION else: - post_data["model"] = "gpt-3.5-turbo" + post_data["model"] = model post_data = {k: v for k, v in post_data.items() if v is not None} diff --git a/service/text_embedding.py b/service/text_embedding.py index cf79bb5..7f49f06 100644 --- a/service/text_embedding.py +++ b/service/text_embedding.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy import time -from lib.config import Config +from libs.config import Config import asyncio import random import threading @@ -22,14 +22,18 @@ class Text2VecEmbeddingQueueTaskInfo(TypedDict): class Text2VecEmbeddingQueue: def __init__(self, model: str) -> None: self.model_name = model + self.embedding_model: SentenceModel = None - self.embedding_model = SentenceModel(self.model_name) self.task_map: dict[int, Text2VecEmbeddingQueueTaskInfo] = {} self.task_list: list[Text2VecEmbeddingQueueTaskInfo] = [] self.lock = threading.Lock() self.thread: Optional[threading.Thread] = None self.running = False + + + def post_init(self): + self.embedding_model = SentenceModel(self.model_name) async def get_embeddings(self, text: str): @@ -54,6 +58,7 @@ class Text2VecEmbeddingQueue: await asyncio.sleep(0.01) + def pop_task(self, task_id): with self.lock: if task_id in self.task_map: @@ -64,6 +69,7 @@ class Text2VecEmbeddingQueue: return None + def run(self): running = True last_task_time = None @@ -88,15 +94,22 @@ class Text2VecEmbeddingQueue: else: time.sleep(0.01) + def start_queue(self): if not self.running: self.running = True self.thread = threading.Thread(target=self.run) self.thread.start() + class TextEmbeddingService: instance = None + def __init__(self): + self.tiktoken: TikTokenService = None + self.text2vec_queue: Text2VecEmbeddingQueue = None + self.openai_api: OpenAIApi = None + @staticmethod async def create() -> TextEmbeddingService: if TextEmbeddingService.instance is None: @@ -111,21 +124,27 @@ class TextEmbeddingService: if self.embedding_type == "text2vec": embedding_model = Config.get("embedding.embedding_model", "shibing624/text2vec-base-chinese") - self.embedding_queue = Text2VecEmbeddingQueue(model=embedding_model) + self.text2vec_queue = Text2VecEmbeddingQueue(model=embedding_model) elif self.embedding_type == "openai": - self.openai_api: OpenAIApi = await OpenAIApi.create() + api_id = Config.get("embedding.api_id") + self.openai_api: OpenAIApi = await OpenAIApi.create(api_id) - await loop.run_in_executor(None, self.embedding_queue.init) + await loop.run_in_executor(None, self.text2vec_queue.post_init) async def get_text2vec_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): + total_token_usage = 0 for index, doc in enumerate(doc_list): text = doc["text"] - embedding = await self.embedding_queue.get_embeddings(text) + embedding = await self.text2vec_queue.get_embeddings(text) doc["embedding"] = embedding + total_token_usage += await self.tiktoken.get_tokens(text) + if on_index_progress is not None: await on_index_progress(index, len(doc_list)) + return (doc_list, total_token_usage) + async def get_embeddings(self, doc_list: list, on_index_progress: Optional[Callable[[int, int], None]] = None): res_doc_list = deepcopy(doc_list) diff --git a/test/base.py b/test/base.py index 74c4dae..7f75967 100644 --- a/test/base.py +++ b/test/base.py @@ -4,7 +4,7 @@ import pathlib root_path = pathlib.Path(__file__).parent.parent sys.path.append(".") -from lib.config import Config +from libs.config import Config Config.load_config(str(root_path) + "/config.toml") Config.set("server.debug", True) \ No newline at end of file diff --git a/test/chatcomplete.py b/test/chatcomplete.py index c00360e..5ea21e6 100644 --- a/test/chatcomplete.py +++ b/test/chatcomplete.py @@ -16,7 +16,7 @@ async def main(): try: prepare_res = await chat_complete.prepare_chat_complete(question, None, 1, question_tokens, { - "distance_limit": 0.6, + "distance_limit": 20, "limit": 10 }) diff --git a/test/embedding_search.py b/test/embedding_search.py index 969a247..3404bbd 100644 --- a/test/embedding_search.py +++ b/test/embedding_search.py @@ -14,14 +14,14 @@ async def main(): print("\r索引进度:%.1f%%" % (current / length * 100), end="", flush=True) print("") - await embedding_search.update_page_index(on_index_progress) + #await embedding_search.update_page_index(on_index_progress) print("") while True: query = input("请输入要搜索的问题 (.exit 退出):") if query == ".exit": break - res, token_usage = await embedding_search.search(query, 5) + res, token_usage = await embedding_search.search(query, 5, False, 0.1) total_length = 0 if res: for one in res: diff --git a/utils/config.py b/utils/config.py index e30f87a..60ea58c 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,5 +1,5 @@ import time -from lib.config import Config +from libs.config import Config def get_prompt(name: str, type: str, params: dict = {}): sys_params = { diff --git a/utils/local.py b/utils/local.py index 3c69ea0..079867a 100644 --- a/utils/local.py +++ b/utils/local.py @@ -1,5 +1,5 @@ import asyncio -from lib.noawait import NoAwaitPool +from libs.noawait import NoAwaitPool debug = False loop = asyncio.new_event_loop() diff --git a/utils/web.py b/utils/web.py index d7109e1..e3cb506 100644 --- a/utils/web.py +++ b/utils/web.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Dict from aiohttp import web import jwt import uuid -from lib.config import Config +from libs.config import Config ParamRule = Dict[str, Any] @@ -15,7 +15,7 @@ class ParamInvalidException(web.HTTPBadRequest): self.param_list = param_list self.rules = rules param_list_str = "'" + ("', '".join(param_list)) + "'" - super().__init__(f"Param invalid: {param_list_str}") + super().__init__(reason=f"Param invalid: {param_list_str}") async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None): params: dict[str, Any] = {}