增加选择机器人人格功能

master
落雨楓 2 years ago
parent 8ba74fa952
commit e761d4dbcb

@ -5,6 +5,7 @@ import time
import traceback
from api.controller.task.ChatCompleteTask import ChatCompleteTask
from api.model.base import clone_model
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait
from typing import Optional, Callable, TypedDict
@ -284,6 +285,63 @@ class ChatComplete:
"code": "chat-complete-error",
"message": err_msg
}, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_list(request: web.Request):
params = await utils.web.get_param(request, {
"category_id": {
"type": int,
"required": False,
},
"page": {
"type": int,
"required": False,
"default": 1,
}
})
category_id = params.get("category_id")
page = params.get("page")
db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper:
persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id)
page_count = await bot_persona_helper.get_page_count(category_id=category_id)
return await utils.web.api_response(1, {
"list": persona_list,
"page_count": page_count,
}, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_info(request: web.Request):
params = await utils.web.get_param(request, {
"id": {
"type": int,
},
"bot_id": {
"type": str,
}
})
persona_id = params.get("id")
bot_id = params.get("bot_id")
db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper:
if persona_id is not None:
persona_info = await bot_persona_helper.find_by_id(persona_id)
elif bot_id is not None:
persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
else:
return await utils.web.api_response(-1, error={
"code": "invalid-params",
"message": "Invalid params. Please specify id or bot_id."
}, http_status=400, request=request)
return await utils.web.api_response(1, persona_info, request=request)
@staticmethod
@utils.web.token_auth

@ -1,10 +1,15 @@
from __future__ import annotations
import math
from typing import Optional
import sqlalchemy
from api.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped
from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService
class BotPersonaModel(BaseModel):
__tablename__ = "chat_complete_bot_persona"
@ -14,9 +19,12 @@ class BotPersonaModel(BaseModel):
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
category_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaModel):
@ -30,14 +38,31 @@ class BotPersonaHelper(BaseHelper):
await self.session.commit()
return obj
async def get_list(self):
stmt = select(BotPersonaModel).with_only_columns([
BotPersonaModel.id,
BotPersonaModel.bot_name,
BotPersonaModel.bot_avatar,
BotPersonaModel.bot_description
])
async def get_list(self, page: Optional[int] = 1, page_size: Optional[int] = 20, category_id: Optional[int] = None):
offset_index = (page - 1) * page_size
stmt = select(BotPersonaModel) \
.options(load_only("id", "bot_id", "bot_name", "bot_avatar", "bot_description", "updated_at")) \
.order_by(BotPersonaModel.updated_at.desc()) \
.offset(offset_index) \
.limit(page_size)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
return await self.session.scalars(stmt)
async def get_page_count(self, page_size = 50, category_id: Optional[int] = None):
stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
item_count = await self.session.scalar(stmt)
if item_count is None:
item_count = 0
return int(math.ceil(item_count / page_size))
async def find_by_id(self, id: int):
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
@ -45,4 +70,13 @@ class BotPersonaHelper(BaseHelper):
async def find_by_bot_id(self, bot_id: str):
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
return await self.session.scalar(stmt)
return await self.session.scalar(stmt)
async def get_system_prompt(self, bot_id: str) -> str | None:
stmt = select(BotPersonaModel.system_prompt).where(BotPersonaModel.bot_id == bot_id)
return await self.session.scalar(stmt)
@staticmethod
async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None:
async with BotPersonaHelper(dbs) as bot_persona_helper:
return await bot_persona_helper.get_system_prompt(bot_id)

@ -15,7 +15,7 @@ class PageTitleModel(BaseModel):
id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, unique=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)

@ -36,4 +36,6 @@ def init(app: web.Application):
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),
web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list),
web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info),
])

@ -7,7 +7,7 @@ class Config:
@staticmethod
def load_config(file):
with open(file, "r") as f:
with open(file, "r", encoding="utf-8") as f:
Config.values = toml.load(f)
@staticmethod

@ -1,5 +1,6 @@
import asyncio
from noawait import NoAwaitPool
debug = False
loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop)

@ -2,7 +2,7 @@ from local import loop, noawait
from aiohttp import web
from config import Config
import toml
import local
import api.route
import utils.web
from service.database import DatabaseService
@ -10,8 +10,10 @@ from service.mediawiki_api import MediaWikiApi
# Auto create Table
from api.model.base import BaseModel
from api.model.toolkit_ui.page_title import PageTitleModel as _
from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _
from api.model.chat_complete.bot_persona import BotPersonaModel as _
from api.model.embedding_search.title_collection import TitleCollectionModel as _
from api.model.embedding_search.title_index import TitleIndexModel as _
@ -54,6 +56,8 @@ async def stop_noawait_pool(app: web.Application):
if __name__ == '__main__':
Config.load_config("config.toml")
local.debug = Config.get("server.debug", False, bool)
app = web.Application()
if Config.get("database.host"):

@ -4,6 +4,7 @@ import traceback
from typing import Optional, Tuple, TypedDict
import sqlalchemy
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.chat_complete.conversation import (
ConversationChunkHelper,
ConversationChunkModel,
@ -263,7 +264,15 @@ class ChatCompleteService:
)
message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system")
bot_persona = self.conversation_info.extra.get("bot_persona") or "default"
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, bot_persona)
if system_prompt is None:
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, "default")
if system_prompt is None:
system_prompt = utils.config.get_prompt("default", "system")
if system_prompt is None:
raise Exception("System prompt not found.")
# Start chat complete
if on_message is not None:

@ -40,11 +40,10 @@ class DatabaseService:
async def init(self):
db_conf = Config.get("database")
debug_mode = Config.get("debug", False, bool)
loop = local.loop
self.pool = await asyncpg.create_pool(**db_conf, loop=loop)
engine = create_async_engine(get_dsn(), echo=debug_mode)
engine = create_async_engine(get_dsn(), echo=local.debug)
self.engine = engine
self.create_session = async_sessionmaker(engine, expire_on_commit=False)

@ -7,4 +7,4 @@ sys.path.append(root_path)
from config import Config
Config.load_config(root_path + "/config.toml")
Config.set("debug", True)
Config.set("server.debug", True)

@ -0,0 +1,15 @@
import asyncio
import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
async def main():
pass
if __name__ == '__main__':
local.loop.run_until_complete(main())
Loading…
Cancel
Save