更新配置文件库

master
落雨楓 2 years ago
parent 74e8f9c7b0
commit 8ba74fa952

2
.gitignore vendored

@ -141,4 +141,4 @@ dmypy.json
# Cython debug symbols
cython_debug/
/config.py
/config.toml

@ -0,0 +1,43 @@
from __future__ import annotations
import sqlalchemy
from api.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
class BotPersonaCategoryModel(BaseModel):
__tablename__ = "chat_complete_bot_persona_category"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
font_icon: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
color: Mapped[str] = mapped_column(sqlalchemy.String(10), nullable=True)
order: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaCategoryModel):
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
async def update(self, obj: BotPersonaCategoryModel):
obj = await self.session.merge(obj)
await self.session.commit()
return obj
async def remove(self, id: int):
stmt = sqlalchemy.delete(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
await self.session.execute(stmt)
await self.session.commit()
async def get_list(self):
stmt = select(BotPersonaCategoryModel)
return await self.session.scalars(stmt)
async def find_by_id(self, id: int):
stmt = select(BotPersonaCategoryModel).where(BotPersonaCategoryModel.id == id)
return await self.session.scalar(stmt)

@ -4,9 +4,9 @@ from typing import Optional, Type
import asyncpg
from api.model.base import BaseModel
import config
import numpy as np
import sqlalchemy
from config import Config
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
@ -17,6 +17,7 @@ from service.database import DatabaseService
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
class AbstractPageIndexModel(BaseModel):
__abstract__ = True
@ -26,7 +27,7 @@ class AbstractPageIndexModel(BaseModel):
)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
embedding: Mapped[np.ndarray] = mapped_column(Vector(embedding_vector_size))
text: Mapped[str] = mapped_column(sqlalchemy.Text)
text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)

@ -8,10 +8,12 @@ import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine
import config
from config import Config
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
class TitleIndexModel(BaseModel):
__tablename__ = "embedding_search_title_index"
@ -22,7 +24,7 @@ class TitleIndexModel(BaseModel):
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True))
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(embedding_vector_size), nullable=True))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat',

@ -1,71 +0,0 @@
PORT = 8144
HOST = "www.isekai.cn"
MW_API = "http://dev.isekai.cn/api.php"
DEBUG = True
DATABASE = {
"host": "127.0.0.1",
"database": "isekai_toolkit",
"user": "",
"password": "",
"port": "5432",
}
EMBEDDING_VECTOR_SIZE = 1536
OPENAI_API_TYPE = "openai" # openai or azure
OPENAI_API = "https://api.openai.com"
OPENAI_TOKEN = "sk-"
OPENAI_API = None
OPENAI_TOKEN = ""
AZURE_OPENAI_ENDPOINT = "https://your-instance.openai.azure.com"
AZURE_OPENAI_KEY = ""
AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME = ""
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = ""
CHATCOMPLETE_MAX_MEMORY_TOKENS = 1024
CHATCOMPLETE_MAX_INPUT_TOKENS = 768
CHATCOMPLETE_OUTPUT_REPLACE = {
"OpenAI": "オーペンエーアイ",
"ChatGPT": "チャットジーピーティー",
"GPT": "ジーピーティー",
"上下文": "消息",
"AI": "虛擬人物程序",
"语言模型": "虛擬人物程序",
"人工智能程序": "虛擬人物程序",
"語言模型": "虛擬人物程序",
"人工智能程式": "虛擬人物程序",
}
CHATCOMPLETE_BOT_NAME = "寫作助手"
PROMPTS = {
"chat": {
"system_prompt": "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel.",
},
"title": {
"system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion.",
"prompt": "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
},
"suggestions": {
"prompt": "根據下面的對話,提出幾個問題:\n\n{content}"
},
"summary": {
"system_prompt": "You are a writing assistant, you only need to assist in writing, do not express your opinion. Output in Chinese.",
"prompt": "為“{bot_name}”概括下面的聊天記錄排除不重要的對話不要表明自己的意見儘量簡潔。使用中文輸出“User”是同一個人。\n\n{content}"
},
"extracted_doc": {
"prompt": "Here are some relevant informations:\n\n{content}"
}
}
REQUEST_PROXY = "http://127.0.0.1:7890"
AUTH_TOKENS = {
"isekaiwiki": "sk-123456"
}
MW_BOT_LOGIN_USERNAME = "Hyperzlib@ChatComplete"
MW_BOT_LOGIN_PASSWORD = ""

@ -0,0 +1,75 @@
[server]
port = 8144
host = "www.isekai.cn"
debug = false
[database]
host = ""
port = 5432
user = ""
password = ""
database = ""
[mediawiki]
api_endpoint = "http://dev.isekai.cn/api.php"
bot_username = ""
bot_password = ""
[request]
proxy = ""
[authorization]
isekaiwiki = "sk-123456"
[chatcomplete]
enabled = true
bot_name = "寫作助手"
embedding_vector_size = 1536
embedding_type = "openai"
max_memory_tokens = 1280
max_input_tokens = 128
api_type = "azure"
[chatcomplete.openai]
api_endpoint = "https://api.openai.com"
key = "sk-"
[chatcomplete.azure]
api_endpoint = "https://your-instance.openai.azure.com"
key = ""
[chatcomplete.azure.deployments]
chatcomplete = ""
embedding = ""
[chatcomplete.bert]
model = "bert-base-chinese"
[chatcomplete.replace.output]
"OpenAI" = "オーペンエーアイ"
"ChatGPT" = "チャットジーピーティー"
"GPT" = "ジーピーティー"
"上下文" = "消息"
"AI" = "虛擬人物程序"
"语言模型" = "虛擬人物程序"
"人工智能程序" = "虛擬人物程序"
"語言模型" = "虛擬人物程序"
"人工智能程式" = "虛擬人物程序"
[chatcomplete.prompts.chat]
system = "You are a writer. You are the writing assistant of the '異世界百科'. Your name is '{bot_name}'. You need to help users complete the characters and settings in their novel."
[chatcomplete.prompts.make_title]
system = "You are a writing assistant, you only need to assist in writing, do not express your opinion."
prompt = "Write a short title in Chinese for the following conversation, don't use quotes:\n\n{content}"
[chatcomplete.prompts.make_summary]
system = "You are a writing assistant, you only need to assist in writing, do not express your opinion. Reply in Chinese."
prompt = "為“{bot_name}”概括下面的聊天記錄排除不重要的對話不要表明自己的意見儘量簡潔。使用中文輸出對話中的“User”是同一個人。\n\n{content}"
[chatcomplete.prompts.extracted_doc]
prompt = "Here are some relevant informations:\n\n{content}"

@ -0,0 +1,56 @@
from __future__ import annotations
from typing import TypeVar
import toml
class Config:
values: dict = {}
@staticmethod
def load_config(file):
with open(file, "r") as f:
Config.values = toml.load(f)
@staticmethod
def get(key: str, default=None, type=None, empty_is_none=False):
key_path = key.split(".")
value = Config.values
for k in key_path:
if k in value:
value = value[k]
else:
return default
if empty_is_none and value == "":
return None
if type == bool:
if isinstance(value, bool):
return value
elif isinstance(value, int) or isinstance(value, float):
return value != 0
else:
return str(value).lower() in ("yes", "true", "1")
elif type == int:
return int(value)
elif type == float:
return float(value)
elif type == str:
return str(value)
elif type == list:
if not isinstance(value, list):
return []
elif type == dict:
if not isinstance(value, dict):
return {}
else:
return value
@staticmethod
def set(key: str, value):
key_path = key.split(".")
obj = Config.values
for k in key_path[:-1]:
if k not in obj:
obj[k] = {}
obj = obj[k]
obj[key_path[-1]] = value

@ -1,7 +1,7 @@
import asyncio
import asyncpg
import config
import os
from config import Config
conn = None
@ -9,7 +9,8 @@ class Install:
dbi: asyncpg.Connection
async def run(self):
self.dbi = await asyncpg.connect(**config.DATABASE)
db_config = Config.get("database")
self.dbi = await asyncpg.connect(db_config)
args = os.sys.argv
if "--force" in args:
await self.drop_table()
@ -21,6 +22,8 @@ class Install:
print("Table dropped")
async def create_table(self):
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)
await self.dbi.execute("""
CREATE TABLE embedding_search_title_index (
id SERIAL PRIMARY KEY,
@ -29,7 +32,7 @@ class Install:
rev_id INT8 NOT NULL,
embedding VECTOR(%d) NOT NULL
);
""" % (config.EMBEDDING_VECTOR_SIZE))
""" % (embedding_vector_size))
await self.dbi.execute("CREATE INDEX embedding_search_title_index_embedding_idx ON embedding_search_title_index USING ivfflat (embedding vector_cosine_ops);")
print("Table created")

@ -1,7 +1,8 @@
from local import loop, noawait
from aiohttp import web
import config
from config import Config
import toml
import api.route
import utils.web
from service.database import DatabaseService
@ -22,9 +23,13 @@ async def index(request: web.Request):
async def init_mw_api(app: web.Application):
mw_api = MediaWikiApi.create()
if config.MW_BOT_LOGIN_USERNAME and config.MW_BOT_LOGIN_PASSWORD:
bot_username = Config.get("mediawiki.bot_username", "", str)
bot_password = Config.get("mediawiki.bot_password", "", str)
if bot_username and bot_password:
try:
await mw_api.robot_login(config.MW_BOT_LOGIN_USERNAME, config.MW_BOT_LOGIN_PASSWORD)
await mw_api.robot_login(bot_username, bot_password)
except Exception as e:
print("Cannot login to Robot account, please check config.")
@ -47,15 +52,17 @@ async def stop_noawait_pool(app: web.Application):
await noawait.end()
if __name__ == '__main__':
Config.load_config("config.toml")
app = web.Application()
if config.DATABASE:
if Config.get("database.host"):
app.on_startup.append(init_database)
if config.MW_API:
if Config.get("mediawiki.api_endpoint"):
app.on_startup.append(init_mw_api)
if config.OPENAI_TOKEN:
if Config.get("chatcomplete.enabled"):
app.on_startup.append(init_tiktoken)
app.on_shutdown.append(stop_noawait_pool)
@ -63,4 +70,6 @@ if __name__ == '__main__':
app.router.add_route('*', '/', index)
api.route.init(app)
web.run_app(app, host='0.0.0.0', port=config.PORT, loop=loop)
server_port = Config.get("port", 8144, int)
web.run_app(app, host='0.0.0.0', port=server_port, loop=loop)

@ -12,6 +12,7 @@ PyJWT==2.6.0
asyncpg-stubs==0.27.0
sqlalchemy==2.0.17
aiohttp-sse-client2==0.3.0
OpenCC==1.1.6
OpenCC==1.1.1
event-emitter-asyncio==1.0.4
tiktoken-async==0.3.2
tiktoken-async==0.3.2
toml==0.10.2

@ -1,6 +1,6 @@
from __future__ import annotations
import time
import config
from config import Config
import asyncio
import random
import threading

@ -11,7 +11,7 @@ from api.model.chat_complete.conversation import (
import sys
from api.model.toolkit_ui.conversation import ConversationHelper, ConversationModel
import config
from config import Config
import utils.config, utils.web
from aiohttp import web
@ -122,9 +122,10 @@ class ChatCompleteService:
else:
self.question_tokens = question_tokens
max_input_tokens = Config.get("chatcomplete.max_input_tokens", 768, int)
if (
len(question) * 4 > config.CHATCOMPLETE_MAX_INPUT_TOKENS
and self.question_tokens > config.CHATCOMPLETE_MAX_INPUT_TOKENS
len(question) * 4 > max_input_tokens
and self.question_tokens > max_input_tokens
):
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
@ -174,7 +175,8 @@ class ChatCompleteService:
)
# If the conversation is too long, we need to make a summary
if self.conversation_chunk.tokens > config.CHATCOMPLETE_MAX_MEMORY_TOKENS:
max_memory_tokens = Config.get("chatcomplete.max_memory_tokens", 1280, int)
if self.conversation_chunk.tokens > max_memory_tokens:
summary, tokens = await self.make_summary(
self.conversation_chunk.message_data
)
@ -261,7 +263,7 @@ class ChatCompleteService:
)
message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system_prompt")
system_prompt = utils.config.get_prompt("chat", "system")
# Start chat complete
if on_message is not None:
@ -348,22 +350,23 @@ class ChatCompleteService:
async def make_summary(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
for message_data in message_log_list:
if message_data["role"] == "summary":
chat_log.append(message_data["content"])
elif message_data["role"] == "assistant":
chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
f'{bot_name}: {message_data["content"]}'
)
else:
chat_log.append(f'User: {message_data["content"]}')
chat_log_str = "\n".join(chat_log)
summary_system_prompt = utils.config.get_prompt("summary", "system_prompt")
summary_system_prompt = utils.config.get_prompt("make_summary", "system")
summary_prompt = utils.config.get_prompt(
"summary", "prompt", {"content": chat_log_str}
"make_summary", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(
@ -374,19 +377,21 @@ class ChatCompleteService:
async def make_title(self, message_log_list: list) -> tuple[str, int]:
chat_log: list[str] = []
bot_name = Config.get("chatcomplete.bot_name", "ChatComplete", str)
for message_data in message_log_list:
if message_data["role"] == "assistant":
chat_log.append(
f'{config.CHATCOMPLETE_BOT_NAME}: {message_data["content"]}'
f'{bot_name}: {message_data["content"]}'
)
elif message_data["role"] == "user":
chat_log.append(f'User: {message_data["content"]}')
chat_log_str = "\n".join(chat_log)
title_system_prompt = utils.config.get_prompt("title", "system_prompt")
title_system_prompt = utils.config.get_prompt("make_title", "system")
title_prompt = utils.config.get_prompt(
"title", "prompt", {"content": chat_log_str}
"make_title", "prompt", {"content": chat_log_str}
)
response = await self.openai_api.chat_complete(

@ -4,15 +4,16 @@ from urllib.parse import quote_plus
from aiohttp import web
import asyncpg
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
import config
from config import Config
def get_dsn():
db_conf = Config.get("database")
return "postgresql+asyncpg://%s:%s@%s:%s/%s" % (
quote_plus(config.DATABASE["user"]),
quote_plus(config.DATABASE["password"]),
config.DATABASE["host"],
config.DATABASE["port"],
quote_plus(config.DATABASE["database"]))
quote_plus(db_conf["user"]),
quote_plus(db_conf["password"]),
db_conf["host"],
db_conf["port"],
quote_plus(db_conf["database"]))
class DatabaseService:
instance = None
@ -38,9 +39,12 @@ class DatabaseService:
self.create_session: async_sessionmaker[AsyncSession] = None
async def init(self):
db_conf = Config.get("database")
debug_mode = Config.get("debug", False, bool)
loop = local.loop
self.pool = await asyncpg.create_pool(**config.DATABASE, loop=loop)
self.pool = await asyncpg.create_pool(**db_conf, loop=loop)
engine = create_async_engine(get_dsn(), echo=config.DEBUG)
engine = create_async_engine(get_dsn(), echo=debug_mode)
self.engine = engine
self.create_session = async_sessionmaker(engine, expire_on_commit=False)

@ -4,7 +4,10 @@ import sys
import time
from typing import Optional, TypedDict
import aiohttp
import config
from config import Config
mw_api = Config.get("mw.api_endpoint", "https://www.isekai.cn/api.php")
request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)
class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
@ -53,7 +56,7 @@ class MediaWikiApi:
@staticmethod
def create():
if MediaWikiApi.instance is None:
MediaWikiApi.instance = MediaWikiApi(config.MW_API)
MediaWikiApi.instance = MediaWikiApi(mw_api)
return MediaWikiApi.instance
@ -74,7 +77,7 @@ class MediaWikiApi:
"titles": title,
"inprop": "url"
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -96,7 +99,7 @@ class MediaWikiApi:
"disabletoc": "true",
"disablelimitreport": "true",
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -112,7 +115,7 @@ class MediaWikiApi:
"meta": "siteinfo|userinfo",
"siprop": "general"
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -145,7 +148,7 @@ class MediaWikiApi:
if start_title is not None:
params["apfrom"] = start_title
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -169,7 +172,7 @@ class MediaWikiApi:
"formatversion": "2",
"meta": "userinfo"
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -185,7 +188,7 @@ class MediaWikiApi:
"meta": "tokens",
"type": token_type
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -207,7 +210,7 @@ class MediaWikiApi:
"lgpassword": password,
"lgtoken": token,
}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -239,7 +242,7 @@ class MediaWikiApi:
"namespace": 0,
"format": "json",
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
return data[1]
@ -254,7 +257,7 @@ class MediaWikiApi:
"format": "json",
"formatversion": "2",
}
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
if data["error"]["code"] == "user-not-found":
@ -281,7 +284,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
print(data)
@ -307,7 +310,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
if data["error"]["code"] == "noenoughpoints":
@ -335,7 +338,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@ -360,7 +363,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=config.REQUEST_PROXY) as resp:
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])

@ -4,7 +4,7 @@ import json
from typing import Callable, Optional, TypedDict
import aiohttp
import config
from config import Config
import numpy as np
from aiohttp_sse_client2 import client as sse_client
@ -21,21 +21,24 @@ class ChatCompleteResponse(TypedDict):
total_tokens: int
finish_reason: str
api_type = Config.get("chatcomplete.api_type", "openai", str)
request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
class OpenAIApi:
@staticmethod
def create():
return OpenAIApi()
def __init__(self):
if config.OPENAI_API_TYPE == "azure":
self.api_url = config.AZURE_OPENAI_ENDPOINT
self.api_key = config.AZURE_OPENAI_KEY
if api_type == "azure":
self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.azure.key", type=str)
else:
self.api_url = config.OPENAI_API or "https://api.openai.com"
self.api_key = config.OPENAI_TOKEN
self.api_url = Config.get("chatcomplete.openai.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.openai.key", type=str)
def build_header(self):
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
return {
"content-type": "application/json",
"accept": "application/json",
@ -49,11 +52,12 @@ class OpenAIApi:
}
def get_url(self, method: str):
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
deployments = Config.get("chatcomplete.azure.deployments", type=dict)
if method == "chat/completions":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_CHATCOMPLETE_DEPLOYMENT_NAME + "/" + method
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
elif method == "embeddings":
return self.api_url + "/openai/deployments/" + config.AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME + "/" + method
return self.api_url + "/openai/deployments/" + deployments["embedding"] + "/" + method
else:
return self.api_url + "/v1/" + method
@ -84,13 +88,13 @@ class OpenAIApi:
"input": text_list,
}
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "text-embedding-ada-002"
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
# Azure api does not support batch
for index, text in enumerate(text_list):
retry_num = 0
@ -102,7 +106,7 @@ class OpenAIApi:
params=params,
json={"input": text},
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
proxy=request_proxy) as resp:
data = await resp.json()
@ -134,7 +138,7 @@ class OpenAIApi:
params=params,
json=post_data,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
proxy=request_proxy) as resp:
data = await resp.json()
@ -182,7 +186,7 @@ class OpenAIApi:
"user": user,
}
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
@ -195,7 +199,7 @@ class OpenAIApi:
params=params,
json=post_data,
timeout=30,
proxy=config.REQUEST_PROXY) as resp:
proxy=api_type) as resp:
data = await resp.json()
@ -240,7 +244,7 @@ class OpenAIApi:
"top_p": 0.95
}
if config.OPENAI_API_TYPE == "azure":
if api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
@ -258,7 +262,7 @@ class OpenAIApi:
headers=self.build_header(),
params=params,
json=post_data,
proxy=config.REQUEST_PROXY
proxy=request_proxy
) as session:
async for event in session:
"""

@ -1,7 +1,10 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
root_path = pathlib.Path(__file__).parent.parent
sys.path.append(root_path)
import config
config.DEBUG = True
from config import Config
Config.load_config(root_path + "/config.toml")
Config.set("debug", True)

@ -1,11 +1,12 @@
import config
from config import Config
def get_prompt(name: str, type: str, params: dict = {}):
sys_params = {
"bot_name": config.CHATCOMPLETE_BOT_NAME
"bot_name": Config.get("chatcomplete.bot_name", "ChatGPT"),
}
if name in config.PROMPTS and type in config.PROMPTS[name]:
prompt = config.PROMPTS[name][type]
prompts = Config.get("chatcomplete.prompts")
if name in prompts and type in prompts[name]:
prompt: str = prompts[name][type]
for key in sys_params:
prompt = prompt.replace("{" + key + "}", sys_params[key])

@ -4,8 +4,8 @@ import json
from typing import Any, Optional, Dict
from aiohttp import web
import jwt
import config
import uuid
from config import Config
ParamRule = Dict[str, Any]
@ -119,6 +119,7 @@ def token_auth(f):
@wraps(f)
def decorated_function(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
auth_tokens: dict = Config.get("authorization")
request: web.Request = args[0]
jwt_token = None
@ -141,7 +142,7 @@ def token_auth(f):
jwt_token = token
if sk_token is not None:
if sk_token not in config.AUTH_TOKENS.values():
if sk_token not in auth_tokens.values():
return await api_response(status=-1, error={
"code": "token-invalid",
"target": "token_id",
@ -168,7 +169,7 @@ def token_auth(f):
# Check jwt
try:
data = jwt.decode(jwt_token, config.AUTH_TOKENS[key_id], algorithms=['HS256', 'HS384', 'HS512'])
data = jwt.decode(jwt_token, auth_tokens[key_id], algorithms=['HS256', 'HS384', 'HS512'])
if "sub" not in data:
return await api_response(status=-1, error={
"code": "token-invalid",

Loading…
Cancel
Save