diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index ef2c949..5639f6c 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -5,6 +5,7 @@ import time import traceback from api.controller.task.ChatCompleteTask import ChatCompleteTask from api.model.base import clone_model +from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.toolkit_ui.conversation import ConversationHelper from local import noawait from typing import Optional, Callable, TypedDict @@ -284,6 +285,63 @@ class ChatComplete: "code": "chat-complete-error", "message": err_msg }, http_status=500, request=request) + + @staticmethod + @utils.web.token_auth + async def get_persona_list(request: web.Request): + params = await utils.web.get_param(request, { + "category_id": { + "type": int, + "required": False, + }, + "page": { + "type": int, + "required": False, + "default": 1, + } + }) + + category_id = params.get("category_id") + page = params.get("page") + + db = await DatabaseService.create(request.app) + async with BotPersonaHelper(db) as bot_persona_helper: + persona_list = await bot_persona_helper.get_list(page=page, category_id=category_id) + page_count = await bot_persona_helper.get_page_count(category_id=category_id) + + return await utils.web.api_response(1, { + "list": persona_list, + "page_count": page_count, + }, request=request) + + @staticmethod + @utils.web.token_auth + async def get_persona_info(request: web.Request): + params = await utils.web.get_param(request, { + "id": { + "type": int, + }, + "bot_id": { + "type": str, + } + }) + + persona_id = params.get("id") + bot_id = params.get("bot_id") + + db = await DatabaseService.create(request.app) + async with BotPersonaHelper(db) as bot_persona_helper: + if persona_id is not None: + persona_info = await bot_persona_helper.find_by_id(persona_id) + elif bot_id is not None: + persona_info = await bot_persona_helper.find_by_bot_id(bot_id) + else: + return await utils.web.api_response(-1, error={ + "code": "invalid-params", + "message": "Invalid params. Please specify id or bot_id." + }, http_status=400, request=request) + + return await utils.web.api_response(1, persona_info, request=request) @staticmethod @utils.web.token_auth diff --git a/api/model/chat_complete/bot_persona.py b/api/model/chat_complete/bot_persona.py index f7362b1..9062d0d 100644 --- a/api/model/chat_complete/bot_persona.py +++ b/api/model/chat_complete/bot_persona.py @@ -1,10 +1,15 @@ from __future__ import annotations +import math +from typing import Optional import sqlalchemy from api.model.base import BaseHelper, BaseModel import sqlalchemy from sqlalchemy import select, update -from sqlalchemy.orm import mapped_column, relationship, Mapped +from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped + +from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel +from service.database import DatabaseService class BotPersonaModel(BaseModel): __tablename__ = "chat_complete_bot_persona" @@ -14,9 +19,12 @@ class BotPersonaModel(BaseModel): bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + category_id: Mapped[int] = mapped_column( + sqlalchemy.ForeignKey(BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) + updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) class BotPersonaHelper(BaseHelper): async def add(self, obj: BotPersonaModel): @@ -30,14 +38,31 @@ class BotPersonaHelper(BaseHelper): await self.session.commit() return obj - async def get_list(self): - stmt = select(BotPersonaModel).with_only_columns([ - BotPersonaModel.id, - BotPersonaModel.bot_name, - BotPersonaModel.bot_avatar, - BotPersonaModel.bot_description - ]) + async def get_list(self, page: Optional[int] = 1, page_size: Optional[int] = 20, category_id: Optional[int] = None): + offset_index = (page - 1) * page_size + + stmt = select(BotPersonaModel) \ + .options(load_only("id", "bot_id", "bot_name", "bot_avatar", "bot_description", "updated_at")) \ + .order_by(BotPersonaModel.updated_at.desc()) \ + .offset(offset_index) \ + .limit(page_size) + + if category_id is not None: + stmt = stmt.where(BotPersonaModel.category_id == category_id) + return await self.session.scalars(stmt) + + async def get_page_count(self, page_size = 50, category_id: Optional[int] = None): + stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel) + + if category_id is not None: + stmt = stmt.where(BotPersonaModel.category_id == category_id) + + item_count = await self.session.scalar(stmt) + if item_count is None: + item_count = 0 + + return int(math.ceil(item_count / page_size)) async def find_by_id(self, id: int): stmt = select(BotPersonaModel).where(BotPersonaModel.id == id) @@ -45,4 +70,13 @@ class BotPersonaHelper(BaseHelper): async def find_by_bot_id(self, bot_id: str): stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id) - return await self.session.scalar(stmt) \ No newline at end of file + return await self.session.scalar(stmt) + + async def get_system_prompt(self, bot_id: str) -> str | None: + stmt = select(BotPersonaModel.system_prompt).where(BotPersonaModel.bot_id == bot_id) + return await self.session.scalar(stmt) + + @staticmethod + async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None: + async with BotPersonaHelper(dbs) as bot_persona_helper: + return await bot_persona_helper.get_system_prompt(bot_id) \ No newline at end of file diff --git a/api/model/toolkit_ui/page_title.py b/api/model/toolkit_ui/page_title.py index a395e43..cedb811 100644 --- a/api/model/toolkit_ui/page_title.py +++ b/api/model/toolkit_ui/page_title.py @@ -15,7 +15,7 @@ class PageTitleModel(BaseModel): id: Mapped[int] = mapped_column( sqlalchemy.Integer, primary_key=True, autoincrement=True) - page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) + page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, unique=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), nullable=True) updated_at: Mapped[int] = mapped_column(sqlalchemy.BigInteger, index=True) diff --git a/api/route.py b/api/route.py index cb45799..a87fa1a 100644 --- a/api/route.py +++ b/api/route.py @@ -36,4 +36,6 @@ def init(app: web.Application): web.route('POST', '/chatcomplete/message', ChatComplete.start_chat_complete), web.route('GET', '/chatcomplete/message/stream', ChatComplete.chat_complete_stream), web.route('POST', '/chatcomplete/get_point_cost', ChatComplete.get_point_cost), + web.route('*', '/chatcomplete/persona/list', ChatComplete.get_persona_list), + web.route('*', '/chatcomplete/persona/info', ChatComplete.get_persona_info), ]) diff --git a/config.py b/config.py index 5d932d4..dd99ab5 100644 --- a/config.py +++ b/config.py @@ -7,7 +7,7 @@ class Config: @staticmethod def load_config(file): - with open(file, "r") as f: + with open(file, "r", encoding="utf-8") as f: Config.values = toml.load(f) @staticmethod diff --git a/local.py b/local.py index 3862e5d..9febcce 100644 --- a/local.py +++ b/local.py @@ -1,5 +1,6 @@ import asyncio from noawait import NoAwaitPool +debug = False loop = asyncio.new_event_loop() noawait = NoAwaitPool(loop) \ No newline at end of file diff --git a/main.py b/main.py index e01f52a..9f2cfb8 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from local import loop, noawait from aiohttp import web from config import Config -import toml +import local import api.route import utils.web from service.database import DatabaseService @@ -10,8 +10,10 @@ from service.mediawiki_api import MediaWikiApi # Auto create Table from api.model.base import BaseModel +from api.model.toolkit_ui.page_title import PageTitleModel as _ from api.model.toolkit_ui.conversation import ConversationModel as _ from api.model.chat_complete.conversation import ConversationChunkModel as _ +from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel as _ from api.model.chat_complete.bot_persona import BotPersonaModel as _ from api.model.embedding_search.title_collection import TitleCollectionModel as _ from api.model.embedding_search.title_index import TitleIndexModel as _ @@ -54,6 +56,8 @@ async def stop_noawait_pool(app: web.Application): if __name__ == '__main__': Config.load_config("config.toml") + local.debug = Config.get("server.debug", False, bool) + app = web.Application() if Config.get("database.host"): diff --git a/service/chat_complete.py b/service/chat_complete.py index f5012c7..5dcdabd 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -4,6 +4,7 @@ import traceback from typing import Optional, Tuple, TypedDict import sqlalchemy +from api.model.chat_complete.bot_persona import BotPersonaHelper from api.model.chat_complete.conversation import ( ConversationChunkHelper, ConversationChunkModel, @@ -263,7 +264,15 @@ class ChatCompleteService: ) message_log.append({"role": "user", "content": doc_prompt}) - system_prompt = utils.config.get_prompt("chat", "system") + bot_persona = self.conversation_info.extra.get("bot_persona") or "default" + system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, bot_persona) + if system_prompt is None: + system_prompt = await BotPersonaHelper.get_cached_system_prompt(self.dbs, "default") + if system_prompt is None: + system_prompt = utils.config.get_prompt("default", "system") + + if system_prompt is None: + raise Exception("System prompt not found.") # Start chat complete if on_message is not None: diff --git a/service/database.py b/service/database.py index 22e9d93..807ece7 100644 --- a/service/database.py +++ b/service/database.py @@ -40,11 +40,10 @@ class DatabaseService: async def init(self): db_conf = Config.get("database") - debug_mode = Config.get("debug", False, bool) loop = local.loop self.pool = await asyncpg.create_pool(**db_conf, loop=loop) - engine = create_async_engine(get_dsn(), echo=debug_mode) + engine = create_async_engine(get_dsn(), echo=local.debug) self.engine = engine self.create_session = async_sessionmaker(engine, expire_on_commit=False) \ No newline at end of file diff --git a/test/base.py b/test/base.py index 6d40567..113ba39 100644 --- a/test/base.py +++ b/test/base.py @@ -7,4 +7,4 @@ sys.path.append(root_path) from config import Config Config.load_config(root_path + "/config.toml") -Config.set("debug", True) \ No newline at end of file +Config.set("server.debug", True) \ No newline at end of file diff --git a/test/create_token.py b/test/create_token.py new file mode 100644 index 0000000..1e63e63 --- /dev/null +++ b/test/create_token.py @@ -0,0 +1,15 @@ +import asyncio +import base + +from sqlalchemy import select +from api.model.embedding_search.title_index import TitleIndexModel +import local +from service.database import DatabaseService + +from service.embedding_search import EmbeddingSearchService + +async def main(): + pass + +if __name__ == '__main__': + local.loop.run_until_complete(main()) \ No newline at end of file