增加选择机器人人格功能

master
落雨楓 2 years ago
parent 8ba74fa952
commit e761d4dbcb

@ -5,6 +5,7 @@ import time
import traceback import traceback
from api.controller.task.ChatCompleteTask import ChatCompleteTask from api.controller.task.ChatCompleteTask import ChatCompleteTask
from api.model.base import clone_model from api.model.base import clone_model
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.toolkit_ui.conversation import ConversationHelper from api.model.toolkit_ui.conversation import ConversationHelper
from local import noawait from local import noawait
from typing import Optional, Callable, TypedDict from typing import Optional, Callable, TypedDict
@ -284,6 +285,63 @@ class ChatComplete:
"code": "chat-complete-error", "code": "chat-complete-error",
"message": err_msg "message": err_msg
}, http_status=500, request=request) }, http_status=500, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_list(request: web.Request):
params = await utils.web.get_param(request, {
"category_id": {
"type": int,
"required": False,
},
"page": {
"type": int,
"required": False,
"default": 1,
}
})
category_id = params.get("category_id")
page = params.get("page")
db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper:
persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id)
page_count = await bot_persona_helper.get_page_count(category_id=category_id)
return await utils.web.api_response(1, {
"list": persona_list,
"page_count": page_count,
}, request=request)
@staticmethod
@utils.web.token_auth
async def get_persona_info(request: web.Request):
params = await utils.web.get_param(request, {
"id": {
"type": int,
},
"bot_id": {
"type": str,
}
})
persona_id = params.get("id")
bot_id = params.get("bot_id")
db = await DatabaseService.create(request.app)
async with BotPersonaHelper(db) as bot_persona_helper:
if persona_id is not None:
persona_info = await bot_persona_helper.find_by_id(persona_id)
elif bot_id is not None:
persona_info = await bot_persona_helper.find_by_bot_id(bot_id)
else:
return await utils.web.api_response(-1, error={
"code": "invalid-params",
"message": "Invalid params. Please specify id or bot_id."
}, http_status=400, request=request)
return await utils.web.api_response(1, persona_info, request=request)
@staticmethod @staticmethod
@utils.web.token_auth @utils.web.token_auth

@ -1,10 +1,15 @@
from __future__ import annotations from __future__ import annotations
import math
from typing import Optional
import sqlalchemy import sqlalchemy
from api.model.base import BaseHelper, BaseModel from api.model.base import BaseHelper, BaseModel
import sqlalchemy import sqlalchemy
from sqlalchemy import select, update from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, Mapped from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService
class BotPersonaModel(BaseModel): class BotPersonaModel(BaseModel):
__tablename__ = "chat_complete_bot_persona" __tablename__ = "chat_complete_bot_persona"
@ -14,9 +19,12 @@ class BotPersonaModel(BaseModel):
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
category_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey(BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
class BotPersonaHelper(BaseHelper): class BotPersonaHelper(BaseHelper):
async def add(self, obj: BotPersonaModel): async def add(self, obj: BotPersonaModel):
@ -30,14 +38,31 @@ class BotPersonaHelper(BaseHelper):
await self.session.commit() await self.session.commit()
return obj return obj
async def get_list(self): async def get_list(self, page: Optional[int] = 1, page_size: Optional[int] = 20, category_id: Optional[int] = None):
stmt = select(BotPersonaModel).with_only_columns([ offset_index = (page - 1) * page_size
BotPersonaModel.id,
BotPersonaModel.bot_name, stmt = select(BotPersonaModel) \
BotPersonaModel.bot_avatar, .options(load_only("id", "bot_id", "bot_name", "bot_avatar", "bot_description", "updated_at")) \
BotPersonaModel.bot_description .order_by(BotPersonaModel.updated_at.desc()) \
]) .offset(offset_index) \
.limit(page_size)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
return await self.session.scalars(stmt) return await self.session.scalars(stmt)
async def get_page_count(self, page_size = 50, category_id: Optional[int] = None):
stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
if category_id is not None:
stmt = stmt.where(BotPersonaModel.category_id == category_id)
item_count = await self.session.scalar(stmt)
if item_count is None:
item_count = 0
return int(math.ceil(item_count / page_size))
async def find_by_id(self, id: int): async def find_by_id(self, id: int):
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id) stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
@ -45,4 +70,13 @@ class BotPersonaHelper(BaseHelper):
async def find_by_bot_id(self, bot_id: str): async def find_by_bot_id(self, bot_id: str):
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id) stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
return await self.session.scalar(stmt) return await self.session.scalar(stmt)
async def get_system_prompt(self, bot_id: str) -> str | None:
stmt = select(BotPersonaModel.system_prompt).where(BotPersonaModel.bot_id == bot_id)
return await self.session.scalar(stmt)
@staticmethod
async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None:
async with BotPersonaHelper(dbs) as bot_persona_helper:
return await bot_persona_helper.get_system_prompt(bot_id)

@ -15,7 +15,7 @@ class PageTitleModel(BaseModel):
id: Mapped[int] = mapped_column( id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=True) sqlalchemy.Integer, primary_key=True, autoincrement=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, unique=True)
title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True)
updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True)

@ -36,4 +36,6 @@ def init(app: web.Application):
web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete),
web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream),
web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost),
web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list),
web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info),
]) ])

@ -7,7 +7,7 @@ class Config:
@staticmethod @staticmethod
def load_config(file): def load_config(file):
with open(file, "r") as f: with open(file, "r", encoding="utf-8") as f:
Config.values = toml.load(f) Config.values = toml.load(f)
@staticmethod @staticmethod

@ -1,5 +1,6 @@
import asyncio import asyncio
from noawait import NoAwaitPool from noawait import NoAwaitPool
debug = False
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
noawait = NoAwaitPool(loop) noawait = NoAwaitPool(loop)

@ -2,7 +2,7 @@ from local import loop, noawait
from aiohttp import web from aiohttp import web
from config import Config from config import Config
import toml import local
import api.route import api.route
import utils.web import utils.web
from service.database import DatabaseService from service.database import DatabaseService
@ -10,8 +10,10 @@ from service.mediawiki_api import MediaWikiApi
# Auto create Table # Auto create Table
from api.model.base import BaseModel from api.model.base import BaseModel
from api.model.toolkit_ui.page_title import PageTitleModel as _
from api.model.toolkit_ui.conversation import ConversationModel as _ from api.model.toolkit_ui.conversation import ConversationModel as _
from api.model.chat_complete.conversation import ConversationChunkModel as _ from api.model.chat_complete.conversation import ConversationChunkModel as _
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _
from api.model.chat_complete.bot_persona import BotPersonaModel as _ from api.model.chat_complete.bot_persona import BotPersonaModel as _
from api.model.embedding_search.title_collection import TitleCollectionModel as _ from api.model.embedding_search.title_collection import TitleCollectionModel as _
from api.model.embedding_search.title_index import TitleIndexModel as _ from api.model.embedding_search.title_index import TitleIndexModel as _
@ -54,6 +56,8 @@ async def stop_noawait_pool(app: web.Application):
if __name__ == '__main__': if __name__ == '__main__':
Config.load_config("config.toml") Config.load_config("config.toml")
local.debug = Config.get("server.debug", False, bool)
app = web.Application() app = web.Application()
if Config.get("database.host"): if Config.get("database.host"):

@ -4,6 +4,7 @@ import traceback
from typing import Optional, Tuple, TypedDict from typing import Optional, Tuple, TypedDict
import sqlalchemy import sqlalchemy
from api.model.chat_complete.bot_persona import BotPersonaHelper
from api.model.chat_complete.conversation import ( from api.model.chat_complete.conversation import (
ConversationChunkHelper, ConversationChunkHelper,
ConversationChunkModel, ConversationChunkModel,
@ -263,7 +264,15 @@ class ChatCompleteService:
) )
message_log.append({"role": "user", "content": doc_prompt}) message_log.append({"role": "user", "content": doc_prompt})
system_prompt = utils.config.get_prompt("chat", "system") bot_persona = self.conversation_info.extra.get("bot_persona") or "default"
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, bot_persona)
if system_prompt is None:
system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, "default")
if system_prompt is None:
system_prompt = utils.config.get_prompt("default", "system")
if system_prompt is None:
raise Exception("System prompt not found.")
# Start chat complete # Start chat complete
if on_message is not None: if on_message is not None:

@ -40,11 +40,10 @@ class DatabaseService:
async def init(self): async def init(self):
db_conf = Config.get("database") db_conf = Config.get("database")
debug_mode = Config.get("debug", False, bool)
loop = local.loop loop = local.loop
self.pool = await asyncpg.create_pool(**db_conf, loop=loop) self.pool = await asyncpg.create_pool(**db_conf, loop=loop)
engine = create_async_engine(get_dsn(), echo=debug_mode) engine = create_async_engine(get_dsn(), echo=local.debug)
self.engine = engine self.engine = engine
self.create_session = async_sessionmaker(engine, expire_on_commit=False) self.create_session = async_sessionmaker(engine, expire_on_commit=False)

@ -7,4 +7,4 @@ sys.path.append(root_path)
from config import Config from config import Config
Config.load_config(root_path + "/config.toml") Config.load_config(root_path + "/config.toml")
Config.set("debug", True) Config.set("server.debug", True)

@ -0,0 +1,15 @@
import asyncio
import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
async def main():
pass
if __name__ == '__main__':
local.loop.run_until_complete(main())
Loading…
Cancel
Save