Update configuration file library

master
落雨楓 2 years ago
parent 74e8f9c7b0
commit 8ba74fa952

.gitignore vendored

@@ -141,4 +141,4 @@ dmypy.json
# Cython debug symbols
cython_debug/
-/config.py
+/config.toml

@@ -0,0 +1,43 @@
from __future__ import annotations
import sqlalchemy
from api.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped

class BotPersonaCategoryModel(BaseModel):
    __tablename__ = "chat_complete_bot_persona_category"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
    description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
    icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
    font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
    color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True)
    order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)

class BotPersonaHelper(BaseHelper):
    async def add(self, obj: BotPersonaCategoryModel):
        self.session.add(obj)
        await self.session.commit()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: BotPersonaCategoryModel):
        obj = await self.session.merge(obj)
        await self.session.commit()
        return obj

    async def remove(self, id: int):
        stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
        await self.session.execute(stmt)
        await self.session.commit()

    async def get_list(self):
        stmt = select(BotPersonaCategoryModel)
        return await self.session.scalars(stmt)

    async def find_by_id(self, id: int):
        stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
        return await self.session.scalar(stmt)
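Note: a rough usage sketch for the new helper (not part of the commit). It assumes BaseHelper takes an AsyncSession and exposes it as self.session, which is what the methods above imply; the constructor wiring and connection URL are placeholders.

# Hypothetical usage of BotPersonaHelper; the BaseHelper(session) constructor and
# the connection URL below are assumptions, not taken from this repository.
import asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

async def demo():
    engine = create_async_engine("postgresql+asyncpg://user:password@127.0.0.1:5432/isekai_toolkit")
    make_session = async_sessionmaker(engine, expire_on_commit=False)
    async with make_session() as session:
        helper = BotPersonaHelper(session)
        category = await helper.add(BotPersonaCategoryModel(name="Default", order=1))
        print(category.id, (await helper.find_by_id(category.id)).name)

asyncio.run(demo())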

@@ -4,9 +4,9 @@ from typing import Optional, Type
import asyncpg
from api.model.base import BaseModel
-import config
import numpy as np
import sqlalchemy
+from config import Config
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
@@ -17,6 +17,7 @@ from service.database import DatabaseService
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
+embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)

class AbstractPageIndexModel(BaseModel):
    __abstract__ = True
@@ -26,7 +27,7 @@ class AbstractPageIndexModel(BaseModel):
    )
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
-    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
+    embedding: Mapped[np.ndarray] = mapped_column(Vector(embedding_vector_size))
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)

@@ -8,10 +8,12 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine
-import config
+from config import Config
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService

+embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
+
class TitleIndexModel(BaseModel):
    __tablename__ = "embedding_search_title_index"
@@ -22,7 +24,7 @@ class TitleIndexModel(BaseModel):
    collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
    latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
-    embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True))
+    embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(embedding_vector_size), nullable=True))

    embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
        postgresql_using='ivfflat',

@@ -1,71 +0,0 @@
PORT = 8144
HOST = "www.isekai.cn"
MW_API = "http://dev.isekai.cn/api.php"
DEBUG = True
DATABASE = {
    "host": "127.0.0.1",
    "database": "isekai_toolkit",
    "user": "",
    "password": "",
    "port": "5432",
}
EMBEDDING_VECTOR_SIZE = 1536
OPENAI_API_TYPE = "openai" # openai or azure
OPENAI_API = "https://api.openai.com"
OPENAI_TOKEN = "sk-"
OPENAI_API = None
OPENAI_TOKEN = ""
AZURE_OPENAI_ENDPOINT = "https://your-instance.openai.azure.com"
AZURE_OPENAI_KEY = ""
AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME = ""
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = ""
CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024
CHATCOMPLETE_MAX_INPUT_TOKENS = 768
CHATCOMPLETE_OUTPUT_REPLACE = {
    "OpenAI": "オーペンエーアイ",
    "ChatGPT": "チャットジーピーティー",
    "GPT": "ジーピーティー",
    "上下文": "消息",
    "AI": "虛擬人物程序",
    "语言模型": "虛擬人物程序",
    "人工智能程序": "虛擬人物程序",
    "語言模型": "虛擬人物程序",
    "人工智能程式": "虛擬人物程序",
}
CHATCOMPLETE_BOT_NAME = "寫作助手"
PROMPTS = {
    "chat": {
        "system_prompt": "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel.",
    },
    "title": {
        "system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion.",
        "prompt": "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
    },
    "suggestions": {
        "prompt": "根據下面的對話,提出幾個問題:\n\n{content}"
    },
    "summary": {
        "system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion. Output in Chinese.",
        "prompt": "為“{bot_name}”概括下面的聊天記錄,排除不重要的對話,不要表明自己的意見,儘量簡潔。使用中文輸出,“User”是同一個人。\n\n{content}"
    },
    "extracted_doc": {
        "prompt": "Here are some relevant informations:\n\n{content}"
    }
}
REQUEST_PROXY = "http://127.0.0.1:7890"
AUTH_TOKENS = {
    "isekaiwiki": "sk-123456"
}
MW_BOT_LOGIN_USERNAME = "Hyperzlib@ChatComplete"
MW_BOT_LOGIN_PASSWORD = ""

@@ -0,0 +1,75 @@
[server]
port = 8144
host = "www.isekai.cn"
debug = false
[database]
host = ""
port = 5432
user = ""
password = ""
database = ""
[mediawiki]
api_endpoint = "http://dev.isekai.cn/api.php"
bot_username = ""
bot_password = ""
[request]
proxy = ""
[authorization]
isekaiwiki = "sk-123456"
[chatcomplete]
enabled = true
bot_name = "寫作助手"
embedding_vector_size = 1536
embedding_type = "openai"
max_memory_tokens = 1280
max_input_tokens = 128
api_type = "azure"
[chatcomplete.openai]
api_endpoint = "https://api.openai.com"
key = "sk-"
[chatcomplete.azure]
api_endpoint = "https://your-instance.openai.azure.com"
key = ""
[chatcomplete.azure.deployments]
chatcomplete = ""
embedding = ""
[chatcomplete.bert]
model = "bert-base-chinese"
[chatcomplete.replace.output]
"OpenAI" = "オーペンエーアイ"
"ChatGPT" = "チャットジーピーティー"
"GPT" = "ジーピーティー"
"上下文" = "消息"
"AI" = "虛擬人物程序"
"语言模型" = "虛擬人物程序"
"人工智能程序" = "虛擬人物程序"
"語言模型" = "虛擬人物程序"
"人工智能程式" = "虛擬人物程序"
[chatcomplete.prompts.chat]
system = "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel."
[chatcomplete.prompts.make_title]
system = "You are a writing assistant, you only need to assist in writing, do not express your opinion."
prompt = "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
[chatcomplete.prompts.make_summary]
system = "You are a writing assistant, you only need to assist in writing, do not express your opinion. Reply in Chinese."
prompt = "為“{bot_name}”概括下面的聊天記錄,排除不重要的對話,不要表明自己的意見,儘量簡潔。使用中文輸出,對話中的“User”是同一個人。\n\n{content}"
[chatcomplete.prompts.extracted_doc]
prompt = "Here are some relevant informations:\n\n{content}"
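For reference, the nested tables above parse into plain nested dicts, which is what the dotted keys used throughout this commit (for example chatcomplete.azure.deployments) walk through. A minimal sketch using the toml package pinned in requirements.txt:

# Sketch only: shows the shape toml.load produces for the example file above.
import toml

with open("config.toml", "r") as f:
    values = toml.load(f)

print(values["server"]["port"])                        # 8144
print(values["chatcomplete"]["azure"]["deployments"])  # {'chatcomplete': '', 'embedding': ''}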

@@ -0,0 +1,56 @@
from __future__ import annotations
from typing import TypeVar
import toml

class Config:
    values: dict = {}

    @staticmethod
    def load_config(file):
        with open(file, "r") as f:
            Config.values = toml.load(f)

    @staticmethod
    def get(key: str, default=None, type=None, empty_is_none=False):
        key_path = key.split(".")
        value = Config.values
        for k in key_path:
            if k in value:
                value = value[k]
            else:
                return default

        if empty_is_none and value == "":
            return None

        if type == bool:
            if isinstance(value, bool):
                return value
            elif isinstance(value, int) or isinstance(value, float):
                return value != 0
            else:
                return str(value).lower() in ("yes", "true", "1")
        elif type == int:
            return int(value)
        elif type == float:
            return float(value)
        elif type == str:
            return str(value)
        elif type == list:
            if not isinstance(value, list):
                return []
        elif type == dict:
            if not isinstance(value, dict):
                return {}
        else:
            return value

    @staticmethod
    def set(key: str, value):
        key_path = key.split(".")
        obj = Config.values
        for k in key_path[:-1]:
            if k not in obj:
                obj[k] = {}
            obj = obj[k]
        obj[key_path[-1]] = value
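A short usage sketch of the class above (it assumes config.toml sits in the working directory, as main.py below expects): keys are dotted paths into the parsed TOML, default is returned when a path is missing, and type coerces the raw value.

# Mirrors calls made elsewhere in this commit; values refer to the example config.toml.
Config.load_config("config.toml")

vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)   # -> 1536
debug = Config.get("debug", False, bool)                           # key missing -> default False
proxy = Config.get("request.proxy", type=str, empty_is_none=True)  # "" -> None

Config.set("debug", True)  # creates intermediate tables as needed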

@@ -1,7 +1,7 @@
import asyncio
import asyncpg
-import config
import os
+from config import Config

conn = None
@@ -9,7 +9,8 @@ class Install:
    dbi: asyncpg.Connection

    async def run(self):
-        self.dbi = await asyncpg.connect(**config.DATABASE)
+        db_config = Config.get("database")
+        self.dbi = await asyncpg.connect(db_config)
        args = os.sys.argv
        if "--force" in args:
            await self.drop_table()
@@ -21,6 +22,8 @@ class Install:
        print("Table dropped")

    async def create_table(self):
+        embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
+
        await self.dbi.execute("""
            CREATE TABLE embedding_search_title_index (
                id SERIAL PRIMARY KEY,
@@ -29,7 +32,7 @@ class Install:
                rev_id INT8 NOT NULL,
                embedding VECTOR(%d) NOT NULL
            );
-        """ % (config.EMBEDDING_VECTOR_SIZE))
+        """ % (embedding_vector_size))
        await self.dbi.execute("CREATE INDEX embedding_search_title_index_embedding_idx ON embedding_search_title_index USING ivfflat (embedding vector_cosine_ops);")
        print("Table created")

@@ -1,7 +1,8 @@
from local import loop, noawait
from aiohttp import web
-import config
+from config import Config
+import toml
import api.route
import utils.web
from service.database import DatabaseService
@@ -22,9 +23,13 @@ async def index(request: web.Request):

async def init_mw_api(app: web.Application):
    mw_api = MediaWikiApi.create()

-    if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD:
+    bot_username = Config.get("mediawiki.bot_username", "", str)
+    bot_password = Config.get("mediawiki.bot_password", "", str)
+    if bot_username and bot_password:
        try:
-            await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD)
+            await mw_api.robot_login(bot_username, bot_password)
        except Exception as e:
            print("Cannot login to Robot account, please check config.")
@@ -47,15 +52,17 @@ async def stop_noawait_pool(app: web.Application):
    await noawait.end()

if __name__ == '__main__':
+    Config.load_config("config.toml")
+
    app = web.Application()

-    if config.DATABASE:
+    if Config.get("database.host"):
        app.on_startup.append(init_database)

-    if config.MW_API:
+    if Config.get("mediawiki.api_endpoint"):
        app.on_startup.append(init_mw_api)

-    if config.OPENAI_TOKEN:
+    if Config.get("chatcomplete.enabled"):
        app.on_startup.append(init_tiktoken)

    app.on_shutdown.append(stop_noawait_pool)
@@ -63,4 +70,6 @@ if __name__ == '__main__':
    app.router.add_route('*', '/', index)
    api.route.init(app)

-    web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)
+    server_port = Config.get("port", 8144, int)
+    web.run_app(app, host='0.0.0.0', port=server_port, loop=loop)

@@ -12,6 +12,7 @@ PyJWT==2.6.0
asyncpg-stubs==0.27.0
sqlalchemy==2.0.17
aiohttp-sse-client2==0.3.0
-OpenCC==1.1.6
+OpenCC==1.1.1
event-emitter-asyncio==1.0.4
tiktoken-async==0.3.2
+toml==0.10.2

@@ -1,6 +1,6 @@
from __future__ import annotations
import time
-import config
+from config import Config
import asyncio
import random
import threading

@@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import (
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel

-import config
+from config import Config
import utils.config, utils.web

from aiohttp import web
@@ -122,9 +122,10 @@ class ChatCompleteService:
        else:
            self.question_tokens = question_tokens

+        max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int)
        if (
-            len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS
-            and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS
+            len(question) * 4 > max_input_tokens
+            and self.question_tokens > max_input_tokens
        ):
            # If the question is too long, we need to truncate it
            raise web.HTTPRequestEntityTooLarge()
@@ -174,7 +175,8 @@ class ChatCompleteService:
            )

            # If the conversation is too long, we need to make a summary
-            if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
+            max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int)
+            if self.conversation_chunk.tokens > max_memory_tokens:
                summary, tokens = await self.make_summary(
                    self.conversation_chunk.message_data
                )
@@ -261,7 +263,7 @@ class ChatCompleteService:
            )
            message_log.append({"role": "user", "content": doc_prompt})

        system_prompt = utils.config.get_prompt("chat", "system_prompt")
-        system_prompt = utils.config.get_prompt("chat", "system_prompt")
+        system_prompt = utils.config.get_prompt("chat", "system")

        # Start chat complete
        if on_message is not None:
@@ -348,22 +350,23 @@ class ChatCompleteService:
    async def make_summary(self, message_log_list: list) -> tuple[str, int]:
        chat_log: list[str] = []

+        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
        for message_data in message_log_list:
            if message_data["role"] == "summary":
                chat_log.append(message_data["content"])
            elif message_data["role"] == "assistant":
                chat_log.append(
-                    f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
+                    f'{bot_name}: {message_data["content"]}'
                )
            else:
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

-        summary_system_prompt = utils.config.get_prompt("summary", "system_prompt")
+        summary_system_prompt = utils.config.get_prompt("make_summary", "system")
        summary_prompt = utils.config.get_prompt(
-            "summary", "prompt", {"content": chat_log_str}
+            "make_summary", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(
@@ -374,19 +377,21 @@ class ChatCompleteService:
    async def make_title(self, message_log_list: list) -> tuple[str, int]:
        chat_log: list[str] = []

+        bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
+
        for message_data in message_log_list:
            if message_data["role"] == "assistant":
                chat_log.append(
-                    f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
+                    f'{bot_name}: {message_data["content"]}'
                )
            elif message_data["role"] == "user":
                chat_log.append(f'User: {message_data["content"]}')

        chat_log_str = "\n".join(chat_log)

-        title_system_prompt = utils.config.get_prompt("title", "system_prompt")
+        title_system_prompt = utils.config.get_prompt("make_title", "system")
        title_prompt = utils.config.get_prompt(
-            "title", "prompt", {"content": chat_log_str}
+            "make_title", "prompt", {"content": chat_log_str}
        )

        response = await self.openai_api.chat_complete(

@@ -4,15 +4,16 @@ from urllib.parse import quote_plus
from aiohttp import web
import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
-import config
+from config import Config

def get_dsn():
+    db_conf = Config.get("database")
    return "postgresql+asyncpg://%s:%s@%s:%s/%s" % (
-        quote_plus(config.DATABASE["user"]),
-        quote_plus(config.DATABASE["password"]),
-        config.DATABASE["host"],
-        config.DATABASE["port"],
-        quote_plus(config.DATABASE["database"]))
+        quote_plus(db_conf["user"]),
+        quote_plus(db_conf["password"]),
+        db_conf["host"],
+        db_conf["port"],
+        quote_plus(db_conf["database"]))

class DatabaseService:
    instance = None
@@ -38,9 +39,12 @@ class DatabaseService:
        self.create_session: async_sessionmaker[AsyncSession] = None

    async def init(self):
+        db_conf = Config.get("database")
+        debug_mode = Config.get("debug", False, bool)
+
        loop = local.loop
-        self.pool = await asyncpg.create_pool(**config.DATABASE, loop=loop)
-        engine = create_async_engine(get_dsn(), echo=config.DEBUG)
+        self.pool = await asyncpg.create_pool(**db_conf, loop=loop)
+        engine = create_async_engine(get_dsn(), echo=debug_mode)
        self.engine = engine
        self.create_session = async_sessionmaker(engine, expire_on_commit=False)
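As a worked example of the DSN the rewritten get_dsn() builds (the [database] values here are made up, not from a real deployment):

# Illustrative only: same format string and quoting as get_dsn() above.
from urllib.parse import quote_plus

db_conf = {"user": "isekai", "password": "p@ss", "host": "127.0.0.1", "port": 5432, "database": "isekai_toolkit"}
print("postgresql+asyncpg://%s:%s@%s:%s/%s" % (
    quote_plus(db_conf["user"]), quote_plus(db_conf["password"]),
    db_conf["host"], db_conf["port"], quote_plus(db_conf["database"])))
# postgresql+asyncpg://isekai:p%40ss@127.0.0.1:5432/isekai_toolkit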

@@ -4,7 +4,10 @@ import sys
import time
from typing import Optional, TypedDict
import aiohttp
-import config
+from config import Config
+
+mw_api = Config.get("mw.api_endpoint", "https://www.isekai.cn/api.php")
+request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)

class MediaWikiApiException(Exception):
    def __init__(self, info: str, code: Optional[str] = None) -> None:
@@ -53,7 +56,7 @@ class MediaWikiApi:
    @staticmethod
    def create():
        if MediaWikiApi.instance is None:
-            MediaWikiApi.instance = MediaWikiApi(config.MW_API)
+            MediaWikiApi.instance = MediaWikiApi(mw_api)
        return MediaWikiApi.instance
@@ -74,7 +77,7 @@ class MediaWikiApi:
                "titles": title,
                "inprop": "url"
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -96,7 +99,7 @@ class MediaWikiApi:
                "disabletoc": "true",
                "disablelimitreport": "true",
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -112,7 +115,7 @@ class MediaWikiApi:
                "meta": "siteinfo|userinfo",
                "siprop": "general"
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -145,7 +148,7 @@ class MediaWikiApi:
            if start_title is not None:
                params["apfrom"] = start_title

-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -169,7 +172,7 @@ class MediaWikiApi:
                "formatversion": "2",
                "meta": "userinfo"
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -185,7 +188,7 @@ class MediaWikiApi:
                "meta": "tokens",
                "type": token_type
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -207,7 +210,7 @@ class MediaWikiApi:
                "lgpassword": password,
                "lgtoken": token,
            }
-            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -239,7 +242,7 @@ class MediaWikiApi:
                "namespace": 0,
                "format": "json",
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()

                return data[1]
@@ -254,7 +257,7 @@ class MediaWikiApi:
                "format": "json",
                "formatversion": "2",
            }
-            async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
+            async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    if data["error"]["code"] == "user-not-found":
@@ -281,7 +284,7 @@ class MediaWikiApi:
            }
            # Filter out None values
            post_data = {k: v for k, v in post_data.items() if v is not None}
-            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    print(data)
@@ -307,7 +310,7 @@ class MediaWikiApi:
            }
            # Filter out None values
            post_data = {k: v for k, v in post_data.items() if v is not None}
-            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    if data["error"]["code"] == "noenoughpoints":
@@ -335,7 +338,7 @@ class MediaWikiApi:
            }
            # Filter out None values
            post_data = {k: v for k, v in post_data.items() if v is not None}
-            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -360,7 +363,7 @@ class MediaWikiApi:
            }
            # Filter out None values
            post_data = {k: v for k, v in post_data.items() if v is not None}
-            async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
+            async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
                data = await resp.json()
                if "error" in data:
                    raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])

@@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict
import aiohttp
-import config
+from config import Config
import numpy as np

from aiohttp_sse_client2 import client as sse_client
@@ -21,21 +21,24 @@ class ChatCompleteResponse(TypedDict):
    total_tokens: int
    finish_reason: str

+api_type = Config.get("chatcomplete.api_type", "openai", str)
+request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
+
class OpenAIApi:
    @staticmethod
    def create():
        return OpenAIApi()

    def __init__(self):
-        if config.OPENAI_API_TYPE == "azure":
-            self.api_url = config.AZURE_OPENAI_ENDPOINT
-            self.api_key = config.AZURE_OPENAI_KEY
+        if api_type == "azure":
+            self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
+            self.api_key = Config.get("chatcomplete.azure.key", type=str)
        else:
-            self.api_url = config.OPENAI_API or "https://api.openai.com"
-            self.api_key = config.OPENAI_TOKEN
+            self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str)
+            self.api_key = Config.get("chatcomplete.openai.key", type=str)

    def build_header(self):
-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
            return {
                "content-type": "application/json",
                "accept": "application/json",
@@ -49,11 +52,12 @@ class OpenAIApi:
            }

    def get_url(self, method: str):
-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
+            deployments = Config.get("chatcomplete.azure.deployments", type=dict)
            if method == "chat/completions":
-                return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
+                return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
            elif method == "embeddings":
-                return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
+                return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method
        else:
            return self.api_url + "/v1/" + method
@@ -84,13 +88,13 @@ class OpenAIApi:
            "input": text_list,
        }

-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
            params["api-version"] = "2023-05-15"
        else:
            post_data["model"] = "text-embedding-ada-002"

-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
            # Azure api does not support batch
            for index, text in enumerate(text_list):
                retry_num = 0
@@ -102,7 +106,7 @@ class OpenAIApi:
                            params=params,
                            json={"input": text},
                            timeout=30,
-                            proxy=config.REQUEST_PROXY) as resp:
+                            proxy=request_proxy) as resp:

                        data = await resp.json()
@@ -134,7 +138,7 @@ class OpenAIApi:
                    params=params,
                    json=post_data,
                    timeout=30,
-                    proxy=config.REQUEST_PROXY) as resp:
+                    proxy=request_proxy) as resp:

                data = await resp.json()
@@ -182,7 +186,7 @@ class OpenAIApi:
            "user": user,
        }

-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
            params["api-version"] = "2023-05-15"
        else:
            post_data["model"] = "gpt-3.5-turbo"
@@ -195,7 +199,7 @@ class OpenAIApi:
                params=params,
                json=post_data,
                timeout=30,
-                proxy=config.REQUEST_PROXY) as resp:
+                proxy=api_type) as resp:

            data = await resp.json()
@@ -240,7 +244,7 @@ class OpenAIApi:
            "top_p": 0.95
        }

-        if config.OPENAI_API_TYPE == "azure":
+        if api_type == "azure":
            params["api-version"] = "2023-05-15"
        else:
            post_data["model"] = "gpt-3.5-turbo"
@@ -258,7 +262,7 @@ class OpenAIApi:
            headers=self.build_header(),
            params=params,
            json=post_data,
-            proxy=config.REQUEST_PROXY
+            proxy=request_proxy
        ) as session:
            async for event in session:
                """

@@ -1,7 +1,10 @@
import sys
import pathlib

-sys.path.append(str(pathlib.Path(__file__).parent.parent))
+root_path = pathlib.Path(__file__).parent.parent
+sys.path.append(root_path)

-import config
+from config import Config

-config.DEBUG = True
+Config.load_config(root_path + "/config.toml")
+Config.set("debug", True)

@@ -1,11 +1,12 @@
-import config
+from config import Config

def get_prompt(name: str, type: str, params: dict = {}):
    sys_params = {
-        "bot_name": config.CHATCOMPLETE_BOT_NAME
+        "bot_name": Config.get("chatcomplete.bot_name", "ChatGPT"),
    }

-    if name in config.PROMPTS and type in config.PROMPTS[name]:
-        prompt = config.PROMPTS[name][type]
+    prompts = Config.get("chatcomplete.prompts")
+    if name in prompts and type in prompts[name]:
+        prompt: str = prompts[name][type]

        for key in sys_params:
            prompt = prompt.replace("{" + key + "}", sys_params[key])
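A usage sketch for the reworked get_prompt: prompt names now match the [chatcomplete.prompts.*] tables and the second argument selects "system" or "prompt". (The substitution of the caller-supplied params dict presumably happens further down the function, which this hunk does not show.)

# Sketch: resolves strings from [chatcomplete.prompts.make_title] in config.toml.
import utils.config

title_system = utils.config.get_prompt("make_title", "system")
title_prompt = utils.config.get_prompt(
    "make_title", "prompt", {"content": "User: hello\nassistant: hi"}
)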

@@ -4,8 +4,8 @@ import json
from typing import Any, Optional, Dict
from aiohttp import web
import jwt
-import config
import uuid
+from config import Config

ParamRule = Dict[str, Any]
@@ -119,6 +119,7 @@ def token_auth(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        async def async_wrapper(*args, **kwargs):
+            auth_tokens: dict = Config.get("authorization")
            request: web.Request = args[0]

            jwt_token = None
@@ -141,7 +142,7 @@ def token_auth(f):
                jwt_token = token

            if sk_token is not None:
-                if sk_token not in config.AUTH_TOKENS.values():
+                if sk_token not in auth_tokens.values():
                    return await api_response(status=-1, error={
                        "code": "token-invalid",
                        "target": "token_id",
@@ -168,7 +169,7 @@ def token_auth(f):
                # Check jwt
                try:
-                    data = jwt.decode(jwt_token, config.AUTH_TOKENS[key_id], algorithms=['HS256', 'HS384', 'HS512'])
+                    data = jwt.decode(jwt_token, auth_tokens[key_id], algorithms=['HS256', 'HS384', 'HS512'])
                    if "sub" not in data:
                        return await api_response(status=-1, error={
                            "code": "token-invalid",

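For testing the updated check, a client token can be minted with PyJWT against one of the [authorization] keys. The key id ("isekaiwiki" here) and the exact claims besides "sub" are assumptions based on this hunk, not a full description of the auth flow.

# Sketch: sign and verify a token against the "isekaiwiki" key from config.toml.
import jwt
from config import Config

Config.load_config("config.toml")
auth_tokens: dict = Config.get("authorization")

token = jwt.encode({"sub": "example-user"}, auth_tokens["isekaiwiki"], algorithm="HS256")
data = jwt.decode(token, auth_tokens["isekaiwiki"], algorithms=["HS256", "HS384", "HS512"])
print(data["sub"])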