Complete bot selection

master
落雨楓 2 years ago
parent e761d4dbcb
commit 9ab83270e7

@@ -309,8 +309,19 @@ class ChatComplete:
persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id)
page_count = await bot_persona_helper.get_page_count(category_id=category_id)
persona_data_list = []
for persona in persona_list:
persona_data_list.append({
"id": persona.id,
"bot_id": persona.bot_id,
"bot_name": persona.bot_name,
"bot_avatar": persona.bot_avatar,
"bot_description": persona.bot_description,
"updated_at": persona.updated_at,
})
return await utils.web.api_response(1, {
"list": persona_list,
"list": persona_data_list,
"page_count": page_count,
}, request=request)
@@ -341,7 +352,12 @@ class ChatComplete:
"message": "Invalid params. Please specify id or bot_id."
}, http_status=400, request=request)
return await utils.web.api_response(1, persona_info, request=request)
persona_info_res = {}
for key, value in persona_info.__dict__.items():
if not key.startswith("_"):
persona_info_res[key] = value
return await utils.web.api_response(1, persona_info_res, request=request)
@staticmethod
@utils.web.token_auth
@@ -359,6 +375,10 @@ class ChatComplete:
"type": int,
"required": False,
},
"bot_id": {
"type": str,
"required": False,
},
"extract_limit": {
"type": int,
"required": False,
@@ -381,6 +401,7 @@ class ChatComplete:
page_title = params.get("title")
question = params.get("question")
conversation_id = params.get("conversation_id")
bot_id = params.get("bot_id")
extract_limit = params.get("extract_limit")
in_collection = params.get("in_collection")
@@ -392,6 +413,7 @@ class ChatComplete:
try:
chat_complete_task = ChatCompleteTask(dbs, user_id, page_title, caller != "user")
init_res = await chat_complete_task.init(question, conversation_id=conversation_id, edit_message_id=edit_message_id,
bot_id=bot_id,
embedding_search={
"limit": extract_limit,
"in_collection": in_collection,

@@ -4,7 +4,11 @@ import time
import traceback
from local import noawait
from typing import Optional, Callable, Union
from service.chat_complete import ChatCompleteService, ChatCompleteServicePrepareResponse, ChatCompleteServiceResponse
from service.chat_complete import (
ChatCompleteService,
ChatCompleteServicePrepareResponse,
ChatCompleteServiceResponse,
)
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchArgs
from service.mediawiki_api import MediaWikiApi, MediaWikiPageNotFoundException
@@ -13,12 +17,15 @@ import utils.web
chat_complete_tasks: dict[str, ChatCompleteTask] = {}
class ChatCompleteTask:
@staticmethod
def get_by_id(task_id: str) -> Union[ChatCompleteTask, None]:
return chat_complete_tasks.get(task_id)
def __init__(self, dbs: DatabaseService, user_id: int, page_title: str, is_system = False):
def __init__(
self, dbs: DatabaseService, user_id: int, page_title: str, is_system=False
):
self.task_id = utils.web.generate_uuid()
self.on_message: list[Callable] = []
self.on_finished: list[Callable] = []
@@ -41,8 +48,14 @@ class ChatCompleteTask:
self.result: Optional[ChatCompleteServiceResponse] = None
self.error: Optional[Exception] = None
async def init(self, question: str, conversation_id: Optional[str] = None, edit_message_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None) -> ChatCompleteServicePrepareResponse:
async def init(
self,
question: str,
conversation_id: Optional[str] = None,
edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
self.tiktoken = await TikTokenService.create()
self.mwapi = MediaWikiApi.create()
@@ -59,17 +72,26 @@ class ChatCompleteTask:
self.transatcion_id: Optional[str] = None
self.point_cost: int = 0
if not self.is_system:
usage_res = await self.mwapi.ai_toolbox_start_transaction(self.user_id, "chatcomplete",
question_tokens, extract_limit)
usage_res = await self.mwapi.ai_toolbox_start_transaction(
self.user_id, "chatcomplete", question_tokens, extract_limit
)
self.transatcion_id = usage_res["transaction_id"]
self.point_cost = usage_res["point_cost"]
res = await self.chat_complete.prepare_chat_complete(question, conversation_id=conversation_id,
user_id=self.user_id, edit_message_id=edit_message_id, embedding_search=embedding_search)
res = await self.chat_complete.prepare_chat_complete(
question,
conversation_id=conversation_id,
user_id=self.user_id,
edit_message_id=edit_message_id,
bot_id=bot_id,
embedding_search=embedding_search,
)
return res
else:
raise MediaWikiPageNotFoundException("Page %s not found." % self.page_title)
raise MediaWikiPageNotFoundException(
"Page %s not found." % self.page_title
)
except Exception as e:
await self.end()
raise e
@@ -81,7 +103,10 @@ class ChatCompleteTask:
try:
await callback(delta_message)
except Exception as e:
print("Error while processing on_message callback: %s" % e, file=sys.stderr)
print(
"Error while processing on_message callback: %s" % e,
file=sys.stderr,
)
traceback.print_exc()
async def _on_finished(self):
@@ -89,7 +114,10 @@ class ChatCompleteTask:
try:
await callback(self.result)
except Exception as e:
print("Error while processing on_finished callback: %s" % e, file=sys.stderr)
print(
"Error while processing on_finished callback: %s" % e,
file=sys.stderr,
)
traceback.print_exc()
async def _on_error(self, err: Exception):
@@ -98,7 +126,9 @@ class ChatCompleteTask:
try:
await callback(err)
except Exception as e:
print("Error while processing on_error callback: %s" % e, file=sys.stderr)
print(
"Error while processing on_error callback: %s" % e, file=sys.stderr
)
traceback.print_exc()
async def run(self) -> ChatCompleteServiceResponse:
@@ -111,7 +141,9 @@ class ChatCompleteTask:
self.result = chat_res
if self.transatcion_id:
await self.mwapi.ai_toolbox_end_transaction(self.transatcion_id, chat_res["total_tokens"])
await self.mwapi.ai_toolbox_end_transaction(
self.transatcion_id, chat_res["total_tokens"]
)
await self._on_finished()
except Exception as e:
@@ -121,7 +153,9 @@ class ChatCompleteTask:
traceback.print_exc()
if self.transatcion_id:
await self.mwapi.ai_toolbox_cancel_transaction(self.transatcion_id, error=err_msg)
await self.mwapi.ai_toolbox_cancel_transaction(
self.transatcion_id, error=err_msg
)
await self._on_error(e)
finally:
@@ -134,13 +168,20 @@ class ChatCompleteTask:
self.is_finished = True
self.finished_time = time.time()
TASK_EXPIRE_TIME = 60 * 10
async def chat_complete_task_gc():
now = time.time()
for task_id in chat_complete_tasks.keys():
task = chat_complete_tasks[task_id]
if task.is_finished and task.finished_time is not None and now > task.finished_time + TASK_EXPIRE_TIME:
if (
task.is_finished
and task.finished_time is not None
and now > task.finished_time + TASK_EXPIRE_TIME
):
del chat_complete_tasks[task_id]
noawait.add_timer(chat_complete_task_gc, 60)
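
A minimal usage sketch of ChatCompleteTask with the new bot_id argument. The constructor, init() and run() signatures come from the hunks above; the database service, user id, question and registration step are placeholders assumed for illustration:

async def demo(dbs):
    task = ChatCompleteTask(dbs, user_id=1, page_title="Example_Page")
    await task.init(
        "What is this page about?",
        bot_id="default",  # new: chooses which bot persona answers
        embedding_search={"limit": 10, "in_collection": False},
    )
    chat_complete_tasks[task.task_id] = task  # assumed: caller registers the task so get_by_id() can find it
    result = await task.run()                 # annotated to return a ChatCompleteServiceResponse
    print(result["total_tokens"])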

@@ -11,21 +11,29 @@ from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService
class BotPersonaModel(BaseModel):
__tablename__ = "chat_complete_bot_persona"
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=True
)
bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
category_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
sqlalchemy.ForeignKey(
BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"
),
index=True,
)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaModel):
self.session.add(obj)
@@ -38,14 +46,30 @@ class BotPersonaHelper(BaseHelper):
await self.session.commit()
return obj
async def get_list(self, page: Optional[int] = 1, page_size: Optional[int] = 20, category_id: Optional[int] = None):
async def get_list(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 20,
category_id: Optional[int] = None,
):
offset_index = (page - 1) * page_size
stmt = select(BotPersonaModel) \
.options(load_only("id", "bot_id", "bot_name", "bot_avatar", "bot_description", "updated_at")) \
.order_by(BotPersonaModel.updated_at.desc()) \
.offset(offset_index) \
stmt = (
select(BotPersonaModel)
.options(
load_only(
BotPersonaModel.id,
BotPersonaModel.bot_id,
BotPersonaModel.bot_name,
BotPersonaModel.bot_avatar,
BotPersonaModel.bot_description,
BotPersonaModel.updated_at,
)
)
.order_by(BotPersonaModel.updated_at.desc())
.offset(offset_index)
.limit(page_size)
)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
@@ -73,7 +97,9 @@ class BotPersonaHelper(BaseHelper):
return await self.session.scalar(stmt)
async def get_system_prompt(self, bot_id: str) -> str | None:
stmt = select(BotPersonaModel.system_prompt).where(BotPersonaModel.bot_id == bot_id)
stmt = select(BotPersonaModel.system_prompt).where(
BotPersonaModel.bot_id == bot_id
)
return await self.session.scalar(stmt)
@staticmethod

@@ -1,7 +1,8 @@
from local import loop, noawait
from config import Config
Config.load_config("config.toml")
from local import loop, noawait
from aiohttp import web
from config import Config
import local
import api.route
import utils.web
@@ -54,8 +55,6 @@ async def stop_noawait_pool(app: web.Application):
await noawait.end()
if __name__ == '__main__':
Config.load_config("config.toml")
local.debug = Config.get("server.debug", False, bool)
app = web.Application()

@@ -50,6 +50,8 @@ class ChatCompleteService:
self.embedding_search = EmbeddingSearchService(dbs, title)
self.conversation_helper = ConversationHelper(dbs)
self.conversation_chunk_helper = ConversationChunkHelper(dbs)
self.bot_persona_helper = BotPersonaHelper(dbs)
self.conversation_info: Optional[ConversationModel] = None
self.conversation_chunk: Optional[ConversationChunkModel] = None
@@ -64,9 +66,12 @@ class ChatCompleteService:
self.user_id = 0
self.question = ""
self.question_tokens: Optional[int] = None
self.bot_id: str = ""
self.conversation_id: Optional[int] = None
self.conversation_start_time: Optional[int] = None
self.chat_system_prompt = ""
self.delta_data = {}
async def __aenter__(self):
@@ -75,6 +80,7 @@ class ChatCompleteService:
await self.embedding_search.__aenter__()
await self.conversation_helper.__aenter__()
await self.conversation_chunk_helper.__aenter__()
await self.bot_persona_helper.__aenter__()
return self
@@ -82,6 +88,7 @@ class ChatCompleteService:
await self.embedding_search.__aexit__(exc_type, exc, tb)
await self.conversation_helper.__aexit__(exc_type, exc, tb)
await self.conversation_chunk_helper.__aexit__(exc_type, exc, tb)
await self.bot_persona_helper.__aexit__(exc_type, exc, tb)
async def page_index_exists(self):
return await self.embedding_search.page_index_exists(False)
@@ -96,6 +103,7 @@ class ChatCompleteService:
user_id: Optional[int] = None,
question_tokens: Optional[int] = None,
edit_message_id: Optional[str] = None,
bot_id: Optional[str] = None,
embedding_search: Optional[EmbeddingSearchArgs] = None,
) -> ChatCompleteServicePrepareResponse:
if user_id is not None:
@@ -104,6 +112,7 @@ class ChatCompleteService:
self.user_id = user_id
self.question = question
self.conversation_start_time = int(time.time())
self.bot_id = bot_id or None
self.conversation_info = None
if conversation_id is not None:
@@ -131,6 +140,17 @@ class ChatCompleteService:
# If the question is too long, we need to truncate it
raise web.HTTPRequestEntityTooLarge()
if self.conversation_info is not None:
self.bot_id = self.conversation_info.extra.get("bot_id") or "default"
bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
if bot_persona is None:
self.bot_id = "default"
bot_persona = await self.bot_persona_helper.find_by_bot_id(self.bot_id)
else:
self.chat_system_prompt = bot_persona.system_prompt
self.conversation_chunk = None
if self.conversation_info is not None:
chunk_id_list = await self.conversation_chunk_helper.get_chunk_id_list(self.conversation_id)
@@ -199,12 +219,23 @@ class ChatCompleteService:
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk)
else:
# Create a new conversation
# Default chat history
init_message_data = []
if bot_persona is not None:
current_time = int(time.time())
for message in bot_persona.message_log:
message["id"] = utils.web.generate_uuid()
message["time"] = current_time
init_message_data.append(message)
title_info = self.embedding_search.title_index
self.conversation_info = ConversationModel(
user_id=self.user_id,
module="chatcomplete",
page_id=title_info.page_id,
rev_id=title_info.latest_rev_id,
extra={"bot_id": self.bot_id},
)
self.conversation_info = await self.conversation_helper.add(
self.conversation_info,
@@ -212,7 +243,7 @@ class ChatCompleteService:
self.conversation_chunk = ConversationChunkModel(
conversation_id=self.conversation_info.id,
message_data=[],
message_data=init_message_data,
tokens=0,
)
self.conversation_chunk = await self.conversation_chunk_helper.add(
@@ -264,16 +295,15 @@ class ChatCompleteService:
)
message_log.append({"role": "user", "content": doc_prompt})
bot_persona = self.conversation_info.extra.get("bot_persona") or "default"
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, bot_persona)
if system_prompt is None:
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, "default")
system_prompt = self.chat_system_prompt
if system_prompt is None:
system_prompt = utils.config.get_prompt("default", "system")
if system_prompt is None:
raise Exception("System prompt not found.")
system_prompt = utils.config.format_prompt(system_prompt)
# Start chat complete
if on_message is not None:
response = await self.openai_api.chat_complete_stream(

@@ -6,9 +6,6 @@ from typing import Optional, TypedDict
import aiohttp
from config import Config
mw_api = Config.get("mw.api_endpoint", "https://www.isekai.cn/api.php")
request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)
class MediaWikiApiException(Exception):
def __init__(self, info: str, code: Optional[str] = None) -> None:
super().__init__(info)
@@ -55,6 +52,7 @@ class MediaWikiApi:
@staticmethod
def create():
mw_api = Config.get("mw.api_endpoint", "https://www.isekai.cn/api.php")
if MediaWikiApi.instance is None:
MediaWikiApi.instance = MediaWikiApi(mw_api)
@@ -62,6 +60,7 @@ class MediaWikiApi:
def __init__(self, api_url: str):
self.api_url = api_url
self.request_proxy = Config.get("request.proxy", type=str, empty_is_none=True)
self.cookie_jar = aiohttp.CookieJar(unsafe=True)
self.login_identity = None
@@ -77,7 +76,7 @@ class MediaWikiApi:
"titles": title,
"inprop": "url"
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -99,7 +98,7 @@ class MediaWikiApi:
"disabletoc": "true",
"disablelimitreport": "true",
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -115,7 +114,7 @@ class MediaWikiApi:
"meta": "siteinfo|userinfo",
"siprop": "general"
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -148,7 +147,7 @@ class MediaWikiApi:
if start_title is not None:
params["apfrom"] = start_title
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -172,7 +171,7 @@ class MediaWikiApi:
"formatversion": "2",
"meta": "userinfo"
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -188,7 +187,7 @@ class MediaWikiApi:
"meta": "tokens",
"type": token_type
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -210,7 +209,7 @@ class MediaWikiApi:
"lgpassword": password,
"lgtoken": token,
}
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -242,7 +241,7 @@ class MediaWikiApi:
"namespace": 0,
"format": "json",
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
return data[1]
@@ -257,7 +256,7 @@ class MediaWikiApi:
"format": "json",
"formatversion": "2",
}
async with session.get(self.api_url, params=params, proxy=request_proxy) as resp:
async with session.get(self.api_url, params=params, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
if data["error"]["code"] == "user-not-found":
@@ -284,7 +283,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
print(data)
@@ -310,7 +309,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
if data["error"]["code"] == "noenoughpoints":
@@ -338,7 +337,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
@@ -363,7 +362,7 @@ class MediaWikiApi:
}
# Filter out None values
post_data = {k: v for k, v in post_data.items() if v is not None}
async with session.post(self.api_url, data=post_data, proxy=request_proxy) as resp:
async with session.post(self.api_url, data=post_data, proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])

@@ -21,16 +21,16 @@ class ChatCompleteResponse(TypedDict):
total_tokens: int
finish_reason: str
api_type = Config.get("chatcomplete.api_type", "openai", str)
request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
class OpenAIApi:
@staticmethod
def create():
return OpenAIApi()
def __init__(self):
if api_type == "azure":
self.api_type = Config.get("chatcomplete.api_type", "openai", str)
self.request_proxy = Config.get("chatcomplete.request_proxy", type=str, empty_is_none=True)
if self.api_type == "azure":
self.api_url = Config.get("chatcomplete.azure.api_endpoint", type=str)
self.api_key = Config.get("chatcomplete.azure.key", type=str)
else:
@@ -38,7 +38,7 @@ class OpenAIApi:
self.api_key = Config.get("chatcomplete.openai.key", type=str)
def build_header(self):
if api_type == "azure":
if self.api_type == "azure":
return {
"content-type": "application/json",
"accept": "application/json",
@@ -52,8 +52,8 @@
}
def get_url(self, method: str):
if api_type == "azure":
deployments = Config.get("chatcomplete.azure.deployments", type=dict)
if self.api_type == "azure":
deployments = Config.get("chatcomplete.azure.deployments")
if method == "chat/completions":
return self.api_url + "/openai/deployments/" + deployments["chatcomplete"] + "/" + method
elif method == "embeddings":
@@ -88,13 +88,13 @@
"input": text_list,
}
if api_type == "azure":
if self.api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "text-embedding-ada-002"
if api_type == "azure":
if self.api_type == "azure":
# Azure api does not support batch
for index, text in enumerate(text_list):
retry_num = 0
@@ -106,7 +106,7 @@
params=params,
json={"input": text},
timeout=30,
proxy=request_proxy) as resp:
proxy=self.request_proxy) as resp:
data = await resp.json()
@@ -138,10 +138,13 @@
params=params,
json=post_data,
timeout=30,
proxy=request_proxy) as resp:
proxy=self.request_proxy) as resp:
data = await resp.json()
if "error" in data:
raise Exception(data["error"])
for one_data in data["data"]:
embedding = one_data["embedding"]
index = one_data["index"]
@@ -186,7 +189,7 @@
"user": user,
}
if api_type == "azure":
if self.api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
@@ -199,7 +202,7 @@
params=params,
json=post_data,
timeout=30,
proxy=api_type) as resp:
proxy=self.request_proxy) as resp:
data = await resp.json()
@@ -244,7 +247,7 @@
"top_p": 0.95
}
if api_type == "azure":
if self.api_type == "azure":
params["api-version"] = "2023-05-15"
else:
post_data["model"] = "gpt-3.5-turbo"
@@ -262,7 +265,7 @@
headers=self.build_header(),
params=params,
json=post_data,
proxy=request_proxy
proxy=self.request_proxy
) as session:
async for event in session:
"""

@@ -1,3 +1,4 @@
import time
from config import Config
def get_prompt(name: str, type: str, params: dict = {}):
@@ -16,3 +17,15 @@ def get_prompt(name: str, type: str, params: dict = {}):
return prompt
else:
return None
def format_prompt(prompt: str, params: dict = {}):
sys_params = {
"bot_name": Config.get("chatcomplete.bot_name", "ChatGPT"),
"current_date": time.strftime("%Y-%m-%d", time.localtime()),
"current_time": time.strftime("%H:%M:%S", time.localtime()),
}
for key in sys_params:
prompt = prompt.replace("{" + key + "}", sys_params[key])
for key in params:
prompt = prompt.replace("{" + key + "}", params[key])
return prompt
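
A small usage sketch of the new format_prompt helper: placeholders in curly braces are filled first from the built-in sys_params (bot_name, current_date, current_time) and then from the caller-supplied params. The prompt text and the extra "title" parameter below are illustrative:

prompt = format_prompt(
    "You are {bot_name}. Today is {current_date}, the time is {current_time}. Topic: {title}.",
    {"title": "Example_Page"},
)
# e.g. "You are ChatGPT. Today is 2024-01-01, the time is 12:00:00. Topic: Example_Page."
# ("ChatGPT" is the default used when chatcomplete.bot_name is not set in the config.)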