From 8ba74fa95221f2fe1d3b8a6de803df84e77bed5d Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Tue, 27 Jun 2023 10:56:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- .../chat_complete/bot_persona_category.py | 43 +++++++++++ api/model/embedding_search/page_index.py | 5 +- api/model/embedding_search/title_index.py | 6 +- config-example.py | 71 ------------------ config.example.toml | 75 +++++++++++++++++++ config.py | 56 ++++++++++++++ install.py | 9 ++- main.py | 23 ++++-- requirements.txt | 5 +- service/bert_embedding.py | 2 +- service/chat_complete.py | 27 ++++--- service/database.py | 20 +++-- service/mediawiki_api.py | 33 ++++---- service/openai_api.py | 40 +++++----- test/base.py | 9 ++- utils/config.py | 9 ++- utils/web.py | 7 +- 18 files changed, 291 insertions(+), 151 deletions(-) create mode 100644 api/model/chat_complete/bot_persona_category.py delete mode 100644 config-example.py create mode 100644 config.example.toml create mode 100644 config.py diff --git a/.gitignore b/.gitignore index 9389032..86cba70 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,4 @@ dmypy.json # Cython debug symbols cython_debug/ -/config.py \ No newline at end of file +/config.toml \ No newline at end of file diff --git a/api/model/chat_complete/bot_persona_category.py b/api/model/chat_complete/bot_persona_category.py new file mode 100644 index 0000000..3bbdc9f --- /dev/null +++ b/api/model/chat_complete/bot_persona_category.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import sqlalchemy +from api.model.base import BaseHelper, BaseModel + +import sqlalchemy +from sqlalchemy import select, update +from sqlalchemy.orm import mapped_column, relationship, Mapped + +class BotPersonaCategoryModel(BaseModel): + __tablename__ = "chat_complete_bot_persona_category" + + id: Mapped[int] = 
mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) + description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True) + order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) + +class BotPersonaHelper(BaseHelper): + async def add(self, obj: BotPersonaCategoryModel): + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj + + async def update(self, obj: BotPersonaCategoryModel): + obj = await self.session.merge(obj) + await self.session.commit() + return obj + + async def remove(self, id: int): + stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id) + await self.session.execute(stmt) + await self.session.commit() + + async def get_list(self): + stmt = select(BotPersonaCategoryModel) + return await self.session.scalars(stmt) + + async def find_by_id(self, id: int): + stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id) + return await self.session.scalar(stmt) \ No newline at end of file diff --git a/api/model/embedding_search/page_index.py b/api/model/embedding_search/page_index.py index 541f375..16bdaf7 100644 --- a/api/model/embedding_search/page_index.py +++ b/api/model/embedding_search/page_index.py @@ -4,9 +4,9 @@ from typing import Optional, Type import asyncpg from api.model.base import BaseModel -import config import numpy as np import sqlalchemy +from config import Config from sqlalchemy import Index, select, update, delete, Select from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.ext.asyncio import AsyncSession @@ -17,6 +17,7 @@ from service.database import DatabaseService page_index_model_list: dict[int, 
Type[AbstractPageIndexModel]] = {} +embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) class AbstractPageIndexModel(BaseModel): __abstract__ = True @@ -26,7 +27,7 @@ class AbstractPageIndexModel(BaseModel): ) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True) - embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)) + embedding: Mapped[np.ndarray] = mapped_column(Vector(embedding_vector_size)) text: Mapped[str] = mapped_column(sqlalchemy.Text) text_len: Mapped[int] = mapped_column(sqlalchemy.Integer) markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True) diff --git a/api/model/embedding_search/title_index.py b/api/model/embedding_search/title_index.py index b79aa0d..568c388 100644 --- a/api/model/embedding_search/title_index.py +++ b/api/model/embedding_search/title_index.py @@ -8,10 +8,12 @@ import sqlalchemy from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.ext.asyncio import AsyncEngine -import config +from config import Config from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService +embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) + class TitleIndexModel(BaseModel): __tablename__ = "embedding_search_title_index" @@ -22,7 +24,7 @@ class TitleIndexModel(BaseModel): collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) - embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True)) + embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(embedding_vector_size), nullable=True)) embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, 
postgresql_using='ivfflat', diff --git a/config-example.py b/config-example.py deleted file mode 100644 index 6aa04c6..0000000 --- a/config-example.py +++ /dev/null @@ -1,71 +0,0 @@ -PORT = 8144 -HOST = "www.isekai.cn" -MW_API = "http://dev.isekai.cn/api.php" - -DEBUG = True - -DATABASE = { - "host": "127.0.0.1", - "database": "isekai_toolkit", - "user": "", - "password": "", - "port": "5432", -} - -EMBEDDING_VECTOR_SIZE = 1536 - -OPENAI_API_TYPE = "openai" # openai or azure -OPENAI_API = "https://api.openai.com" -OPENAI_TOKEN = "sk-" -OPENAI_API = None -OPENAI_TOKEN = "" -AZURE_OPENAI_ENDPOINT = "https://your-instance.openai.azure.com" -AZURE_OPENAI_KEY = "" -AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME = "" -AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = "" - -CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024 -CHATCOMPLETE_MAX_INPUT_TOKENS = 768 - -CHATCOMPLETE_OUTPUT_REPLACE = { - "OpenAI": "オーペンエーアイ", - "ChatGPT": "チャットジーピーティー", - "GPT": "ジーピーティー", - "上下文": "消息", - "AI": "虛擬人物程序", - "语言模型": "虛擬人物程序", - "人工智能程序": "虛擬人物程序", - "語言模型": "虛擬人物程序", - "人工智能程式": "虛擬人物程序", -} - -CHATCOMPLETE_BOT_NAME = "寫作助手" - -PROMPTS = { - "chat": { - "system_prompt": "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel.", - }, - "title": { - "system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion.", - "prompt": "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}" - }, - "suggestions": { - "prompt": "根據下面的對話,提出幾個問題:\n\n{content}" - }, - "summary": { - "system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion. 
Output in Chinese.", - "prompt": "為“{bot_name}”概括下面的聊天記錄,排除不重要的對話,不要表明自己的意見,儘量簡潔。使用中文輸出,“User”是同一個人。\n\n{content}" - }, - "extracted_doc": { - "prompt": "Here are some relevant informations:\n\n{content}" - } -} - -REQUEST_PROXY = "http://127.0.0.1:7890" - -AUTH_TOKENS = { - "isekaiwiki": "sk-123456" -} - -MW_BOT_LOGIN_USERNAME = "Hyperzlib@ChatComplete" -MW_BOT_LOGIN_PASSWORD = "" \ No newline at end of file diff --git a/config.example.toml b/config.example.toml new file mode 100644 index 0000000..27de626 --- /dev/null +++ b/config.example.toml @@ -0,0 +1,75 @@ +[server] +port = 8144 +host = "www.isekai.cn" +debug = false + +[database] +host = "" +port = 5432 +user = "" +password = "" +database = "" + +[mediawiki] +api_endpoint = "http://dev.isekai.cn/api.php" +bot_username = "" +bot_password = "" + +[request] +proxy = "" + +[authorization] +isekaiwiki = "sk-123456" + +[chatcomplete] +enabled = true + +bot_name = "寫作助手" + +embedding_vector_size = 1536 +embedding_type = "openai" + +max_memory_tokens = 1280 +max_input_tokens = 128 + +api_type = "azure" + +[chatcomplete.openai] +api_endpoint = "https://api.openai.com" +key = "sk-" + +[chatcomplete.azure] +api_endpoint = "https://your-instance.openai.azure.com" +key = "" + +[chatcomplete.azure.deployments] +chatcomplete = "" +embedding = "" + +[chatcomplete.bert] +model = "bert-base-chinese" + +[chatcomplete.replace.output] +"OpenAI" = "オーペンエーアイ" +"ChatGPT" = "チャットジーピーティー" +"GPT" = "ジーピーティー" +"上下文" = "消息" +"AI" = "虛擬人物程序" +"语言模型" = "虛擬人物程序" +"人工智能程序" = "虛擬人物程序" +"語言模型" = "虛擬人物程序" +"人工智能程式" = "虛擬人物程序" + +[chatcomplete.prompts.chat] +system = "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel." + +[chatcomplete.prompts.make_title] +system = "You are a writing assistant, you only need to assist in writing, do not express your opinion." 
+prompt = "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}" + +[chatcomplete.prompts.make_summary] +system = "You are a writing assistant, you only need to assist in writing, do not express your opinion. Reply in Chinese." +prompt = "為“{bot_name}”概括下面的聊天記錄,排除不重要的對話,不要表明自己的意見,儘量簡潔。使用中文輸出,對話中的“User”是同一個人。\n\n{content}" + +[chatcomplete.prompts.extracted_doc] +prompt = "Here are some relevant informations:\n\n{content}" \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..5d932d4 --- /dev/null +++ b/config.py @@ -0,0 +1,56 @@ +from __future__ import annotations +from typing import TypeVar +import toml + +class Config: + values: dict = {} + + @staticmethod + def load_config(file): + with open(file, "r") as f: + Config.values = toml.load(f) + + @staticmethod + def get(key: str, default=None, type=None, empty_is_none=False): + key_path = key.split(".") + value = Config.values + for k in key_path: + if k in value: + value = value[k] + else: + return default + + if empty_is_none and value == "": + return None + + if type == bool: + if isinstance(value, bool): + return value + elif isinstance(value, int) or isinstance(value, float): + return value != 0 + else: + return str(value).lower() in ("yes", "true", "1") + elif type == int: + return int(value) + elif type == float: + return float(value) + elif type == str: + return str(value) + elif type == list: + if not isinstance(value, list): + return [] + elif type == dict: + if not isinstance(value, dict): + return {} + else: + return value + + @staticmethod + def set(key: str, value): + key_path = key.split(".") + obj = Config.values + for k in key_path[:-1]: + if k not in obj: + obj[k] = {} + obj = obj[k] + obj[key_path[-1]] = value diff --git a/install.py b/install.py index 407f975..8abef40 100644 --- a/install.py +++ b/install.py @@ -1,7 +1,7 @@ import asyncio import asyncpg -import config import os +from config import Config conn = 
None @@ -9,7 +9,8 @@ class Install: dbi: asyncpg.Connection async def run(self): - self.dbi = await asyncpg.connect(**config.DATABASE) + db_config = Config.get("database") + self.dbi = await asyncpg.connect(**db_config) args = os.sys.argv if "--force" in args: await self.drop_table() @@ -21,6 +22,8 @@ class Install: print("Table dropped") async def create_table(self): + embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int) + await self.dbi.execute(""" CREATE TABLE embedding_search_title_index ( id SERIAL PRIMARY KEY, @@ -29,7 +32,7 @@ class Install: rev_id INT8 NOT NULL, embedding VECTOR(%d) NOT NULL ); - """ % (config.EMBEDDING_VECTOR_SIZE)) + """ % (embedding_vector_size)) await self.dbi.execute("CREATE INDEX embedding_search_title_index_embedding_idx ON embedding_search_title_index USING ivfflat (embedding vector_cosine_ops);") print("Table created") diff --git a/main.py b/main.py index 562e183..e01f52a 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,8 @@ from local import loop, noawait from aiohttp import web -import config +from config import Config +import toml import api.route import utils.web from service.database import DatabaseService @@ -22,9 +23,13 @@ async def index(request: web.Request): async def init_mw_api(app: web.Application): mw_api = MediaWikiApi.create() - if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD: + + bot_username = Config.get("mediawiki.bot_username", "", str) + bot_password = Config.get("mediawiki.bot_password", "", str) + + if bot_username and bot_password: try: - await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD) + await mw_api.robot_login(bot_username, bot_password) except Exception as e: print("Cannot login to Robot account, please check config.") @@ -47,15 +52,17 @@ async def stop_noawait_pool(app: web.Application): await noawait.end() if __name__ == '__main__': + Config.load_config("config.toml") + app = web.Application() - if config.DATABASE: + if 
Config.get("database.host"): app.on_startup.append(init_database) - if config.MW_API: + if Config.get("mediawiki.api_endpoint"): app.on_startup.append(init_mw_api) - if config.OPENAI_TOKEN: + if Config.get("chatcomplete.enabled"): app.on_startup.append(init_tiktoken) app.on_shutdown.append(stop_noawait_pool) @@ -63,4 +70,6 @@ if __name__ == '__main__': app.router.add_route('*', '/', index) api.route.init(app) - web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop) \ No newline at end of file + server_port = Config.get("port", 8144, int) + + web.run_app(app, host='0.0.0.0', port=server_port, loop=loop) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f5eb69d..6c12268 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ PyJWT==2.6.0 asyncpg-stubs==0.27.0 sqlalchemy==2.0.17 aiohttp-sse-client2==0.3.0 -OpenCC==1.1.6 +OpenCC==1.1.1 event-emitter-asyncio==1.0.4 -tiktoken-async==0.3.2 \ No newline at end of file +tiktoken-async==0.3.2 +toml==0.10.2 \ No newline at end of file diff --git a/service/bert_embedding.py b/service/bert_embedding.py index 56450a4..c5cdece 100644 --- a/service/bert_embedding.py +++ b/service/bert_embedding.py @@ -1,6 +1,6 @@ from __future__ import annotations import time -import config +from config import Config import asyncio import random import threading diff --git a/service/chat_complete.py b/service/chat_complete.py index 48aca2e..f5012c7 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import ( import sys from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel -import config +from config import Config import utils.config, utils.web from aiohttp import web @@ -122,9 +122,10 @@ class ChatCompleteService: else: self.question_tokens = question_tokens + max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int) if ( - len(question) * 4 > 
config.CHATCOMPLETE_MAX_INPUT_TOKENS - and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS + len(question) * 4 > max_input_tokens + and self.question_tokens > max_input_tokens ): # If the question is too long, we need to truncate it raise web.HTTPRequestEntityTooLarge() @@ -174,7 +175,8 @@ class ChatCompleteService: ) # If the conversation is too long, we need to make a summary - if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS: + max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int) + if self.conversation_chunk.tokens > max_memory_tokens: summary, tokens = await self.make_summary( self.conversation_chunk.message_data ) @@ -261,7 +263,7 @@ class ChatCompleteService: ) message_log.append({"role": "user", "content": doc_prompt}) - system_prompt = utils.config.get_prompt("chat", "system_prompt") + system_prompt = utils.config.get_prompt("chat", "system") # Start chat complete if on_message is not None: @@ -348,22 +350,23 @@ class ChatCompleteService: async def make_summary(self, message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] + bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) for message_data in message_log_list: if message_data["role"] == "summary": chat_log.append(message_data["content"]) elif message_data["role"] == "assistant": chat_log.append( - f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}' + f'{bot_name}: {message_data["content"]}' ) else: chat_log.append(f'User: {message_data["content"]}') chat_log_str = "\n".join(chat_log) - summary_system_prompt = utils.config.get_prompt("summary", "system_prompt") + summary_system_prompt = utils.config.get_prompt("make_summary", "system") summary_prompt = utils.config.get_prompt( - "summary", "prompt", {"content": chat_log_str} + "make_summary", "prompt", {"content": chat_log_str} ) response = await self.openai_api.chat_complete( @@ -374,19 +377,21 @@ class ChatCompleteService: async def make_title(self, 
message_log_list: list) -> tuple[str, int]: chat_log: list[str] = [] + bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str) + for message_data in message_log_list: if message_data["role"] == "assistant": chat_log.append( - f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}' + f'{bot_name}: {message_data["content"]}' ) elif message_data["role"] == "user": chat_log.append(f'User: {message_data["content"]}') chat_log_str = "\n".join(chat_log) - title_system_prompt = utils.config.get_prompt("title", "system_prompt") + title_system_prompt = utils.config.get_prompt("make_title", "system") title_prompt = utils.config.get_prompt( - "title", "prompt", {"content": chat_log_str} + "make_title", "prompt", {"content": chat_log_str} ) response = await self.openai_api.chat_complete( diff --git a/service/database.py b/service/database.py index 69c623d..22e9d93 100644 --- a/service/database.py +++ b/service/database.py @@ -4,15 +4,16 @@ from urllib.parse import quote_plus from aiohttp import web import asyncpg from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -import config +from config import Config def get_dsn(): + db_conf = Config.get("database") return "postgresql+asyncpg://%s:%s@%s:%s/%s" % ( - quote_plus(config.DATABASE["user"]), - quote_plus(config.DATABASE["password"]), - config.DATABASE["host"], - config.DATABASE["port"], - quote_plus(config.DATABASE["database"])) + quote_plus(db_conf["user"]), + quote_plus(db_conf["password"]), + db_conf["host"], + db_conf["port"], + quote_plus(db_conf["database"])) class DatabaseService: instance = None @@ -38,9 +39,12 @@ class DatabaseService: self.create_session: async_sessionmaker[AsyncSession] = None async def init(self): + db_conf = Config.get("database") + debug_mode = Config.get("debug", False, bool) + loop = local.loop - self.pool = await asyncpg.create_pool(**config.DATABASE, loop=loop) + self.pool = await asyncpg.create_pool(**db_conf, loop=loop) - 
engine = create_async_engine(get_dsn(), echo=config.DEBUG) + engine = create_async_engine(get_dsn(), echo=debug_mode) self.engine = engine self.create_session = async_sessionmaker(engine, expire_on_commit=False) \ No newline at end of file diff --git a/service/mediawiki_api.py b/service/mediawiki_api.py index 61b913a..93678f9 100644 --- a/service/mediawiki_api.py +++ b/service/mediawiki_api.py @@ -4,7 +4,10 @@ import sys import time from typing import Optional, TypedDict import aiohttp -import config +from config import Config + +mw_api = Config.get("mediawiki.api_endpoint", "https://www.isekai.cn/api.php") +request_proxy = Config.get("request.proxy", type=str, empty_is_none=True) class MediaWikiApiException(Exception): def __init__(self, info: str, code: Optional[str] = None) -> None: @@ -53,7 +56,7 @@ class MediaWikiApi: @staticmethod def create(): if MediaWikiApi.instance is None: - MediaWikiApi.instance = MediaWikiApi(config.MW_API) + MediaWikiApi.instance = MediaWikiApi(mw_api) return MediaWikiApi.instance @@ -74,7 +77,7 @@ class MediaWikiApi: "titles": title, "inprop": "url" } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -96,7 +99,7 @@ class MediaWikiApi: "disabletoc": "true", "disablelimitreport": "true", } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -112,7 +115,7 @@ class MediaWikiApi: "meta": "siteinfo|userinfo", "siprop": "general" } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, 
params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -145,7 +148,7 @@ class MediaWikiApi: if start_title is not None: params["apfrom"] = start_title - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -169,7 +172,7 @@ class MediaWikiApi: "formatversion": "2", "meta": "userinfo" } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -185,7 +188,7 @@ class MediaWikiApi: "meta": "tokens", "type": token_type } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -207,7 +210,7 @@ class MediaWikiApi: "lgpassword": password, "lgtoken": token, } - async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: + async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -239,7 +242,7 @@ class MediaWikiApi: "namespace": 0, "format": "json", } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() return data[1] @@ -254,7 +257,7 @@ 
class MediaWikiApi: "format": "json", "formatversion": "2", } - async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + async with session.get(self.api_url, params=params, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: if data["error"]["code"] == "user-not-found": @@ -281,7 +284,7 @@ class MediaWikiApi: } # Filter out None values post_data = {k: v for k, v in post_data.items() if v is not None} - async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: + async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: print(data) @@ -307,7 +310,7 @@ class MediaWikiApi: } # Filter out None values post_data = {k: v for k, v in post_data.items() if v is not None} - async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: + async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: if data["error"]["code"] == "noenoughpoints": @@ -335,7 +338,7 @@ class MediaWikiApi: } # Filter out None values post_data = {k: v for k, v in post_data.items() if v is not None} - async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: + async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) @@ -360,7 +363,7 @@ class MediaWikiApi: } # Filter out None values post_data = {k: v for k, v in post_data.items() if v is not None} - async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp: + async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp: data = await resp.json() if "error" in data: raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) diff --git a/service/openai_api.py 
b/service/openai_api.py index 8d62088..04edf82 100644 --- a/service/openai_api.py +++ b/service/openai_api.py @@ -4,7 +4,7 @@ import json from typing import Callable, Optional, TypedDict import aiohttp -import config +from config import Config import numpy as np from aiohttp_sse_client2 import client as sse_client @@ -21,21 +21,24 @@ class ChatCompleteResponse(TypedDict): total_tokens: int finish_reason: str +api_type = Config.get("chatcomplete.api_type", "openai", str) +request_proxy = Config.get("request.proxy", type=str, empty_is_none=True) + class OpenAIApi: @staticmethod def create(): return OpenAIApi() def __init__(self): - if config.OPENAI_API_TYPE == "azure": - self.api_url = config.AZURE_OPENAI_ENDPOINT - self.api_key = config.AZURE_OPENAI_KEY + if api_type == "azure": + self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str) + self.api_key = Config.get("chatcomplete.azure.key", type=str) else: - self.api_url = config.OPENAI_API or "https://api.openai.com" - self.api_key = config.OPENAI_TOKEN + self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str) + self.api_key = Config.get("chatcomplete.openai.key", type=str) def build_header(self): - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": return { "content-type": "application/json", "accept": "application/json", @@ -49,11 +52,12 @@ class OpenAIApi: } def get_url(self, method: str): - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": + deployments = Config.get("chatcomplete.azure.deployments", type=dict) if method == "chat/completions": - return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method + return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method elif method == "embeddings": - return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method + return self.api_url + "/openai/deployments/" + 
deployments["embedding"] + "/" + method else: return self.api_url + "/v1/" + method @@ -84,13 +88,13 @@ class OpenAIApi: "input": text_list, } - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": params["api-version"] = "2023-05-15" else: post_data["model"] = "text-embedding-ada-002" - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": # Azure api does not support batch for index, text in enumerate(text_list): retry_num = 0 @@ -102,7 +106,7 @@ class OpenAIApi: params=params, json={"input": text}, timeout=30, - proxy=config.REQUEST_PROXY) as resp: + proxy=request_proxy) as resp: data = await resp.json() @@ -134,7 +138,7 @@ class OpenAIApi: params=params, json=post_data, timeout=30, - proxy=config.REQUEST_PROXY) as resp: + proxy=request_proxy) as resp: data = await resp.json() @@ -182,7 +186,7 @@ class OpenAIApi: "user": user, } - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": params["api-version"] = "2023-05-15" else: post_data["model"] = "gpt-3.5-turbo" @@ -195,7 +199,7 @@ class OpenAIApi: params=params, json=post_data, timeout=30, - proxy=config.REQUEST_PROXY) as resp: + proxy=api_type) as resp: data = await resp.json() @@ -240,7 +244,7 @@ class OpenAIApi: "top_p": 0.95 } - if config.OPENAI_API_TYPE == "azure": + if api_type == "azure": params["api-version"] = "2023-05-15" else: post_data["model"] = "gpt-3.5-turbo" @@ -258,7 +262,7 @@ class OpenAIApi: headers=self.build_header(), params=params, json=post_data, - proxy=config.REQUEST_PROXY + proxy=request_proxy ) as session: async for event in session: """ diff --git a/test/base.py b/test/base.py index a184749..6d40567 100644 --- a/test/base.py +++ b/test/base.py @@ -1,7 +1,10 @@ import sys import pathlib -sys.path.append(str(pathlib.Path(__file__).parent.parent)) +root_path = pathlib.Path(__file__).parent.parent +sys.path.append(root_path) -import config -config.DEBUG = True \ No newline at end of file +from config import Config + +Config.load_config(root_path + 
"/config.toml") +Config.set("debug", True) \ No newline at end of file diff --git a/utils/config.py b/utils/config.py index b26b835..caf6615 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,11 +1,12 @@ -import config +from config import Config def get_prompt(name: str, type: str, params: dict = {}): sys_params = { - "bot_name": config.CHATCOMPLETE_BOT_NAME + "bot_name": Config.get("chatcomplete.bot_name", "ChatGPT"), } - if name in config.PROMPTS and type in config.PROMPTS[name]: - prompt = config.PROMPTS[name][type] + prompts = Config.get("chatcomplete.prompts") + if name in prompts and type in prompts[name]: + prompt: str = prompts[name][type] for key in sys_params: prompt = prompt.replace("{" + key + "}", sys_params[key]) diff --git a/utils/web.py b/utils/web.py index 163f333..672ebd3 100644 --- a/utils/web.py +++ b/utils/web.py @@ -4,8 +4,8 @@ import json from typing import Any, Optional, Dict from aiohttp import web import jwt -import config import uuid +from config import Config ParamRule = Dict[str, Any] @@ -119,6 +119,7 @@ def token_auth(f): @wraps(f) def decorated_function(*args, **kwargs): async def async_wrapper(*args, **kwargs): + auth_tokens: dict = Config.get("authorization") request: web.Request = args[0] jwt_token = None @@ -141,7 +142,7 @@ def token_auth(f): jwt_token = token if sk_token is not None: - if sk_token not in config.AUTH_TOKENS.values(): + if sk_token not in auth_tokens.values(): return await api_response(status=-1, error={ "code": "token-invalid", "target": "token_id", @@ -168,7 +169,7 @@ def token_auth(f): # Check jwt try: - data = jwt.decode(jwt_token, config.AUTH_TOKENS[key_id], algorithms=['HS256', 'HS384', 'HS512']) + data = jwt.decode(jwt_token, auth_tokens[key_id], algorithms=['HS256', 'HS384', 'HS512']) if "sub" not in data: return await api_response(status=-1, error={ "code": "token-invalid",