from __future__ import annotations import math from typing import Optional import sqlalchemy from server.model.base import BaseHelper, BaseModel import sqlalchemy from sqlalchemy import select from sqlalchemy.orm import mapped_column, load_only, Mapped from service.database import DatabaseService class BotPersonaModel(BaseModel): __tablename__ = "chat_complete_bot_persona" id: Mapped[int] = mapped_column( sqlalchemy.Integer, primary_key=True, autoincrement=True ) bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True) bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) api_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) model_id: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True, index=True) model_name: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) system_prompt: Mapped[str] = mapped_column(sqlalchemy.String) message_log: Mapped[list] = mapped_column(sqlalchemy.JSON, nullable=True) default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True) cost_fixed: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True) cost_fixed_tokens: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) cost_per_token: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=True) order: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True) class BotPersonaHelper(BaseHelper): async def add(self, obj: BotPersonaModel): self.session.add(obj) await self.session.commit() await self.session.refresh(obj) return obj async def update(self, obj: BotPersonaModel): obj = await self.session.merge(obj) await self.session.commit() return obj async def get_list( self, page: Optional[int] = 1, page_size: Optional[int] = 20 ): offset_index = (page - 1) * page_size stmt = ( select(BotPersonaModel) .options( load_only( BotPersonaModel.id, BotPersonaModel.bot_id, BotPersonaModel.bot_name, BotPersonaModel.bot_avatar, BotPersonaModel.bot_description, BotPersonaModel.model_id, BotPersonaModel.model_name, BotPersonaModel.cost_fixed, BotPersonaModel.order, ) ) .order_by(BotPersonaModel.order.desc()) .offset(offset_index) .limit(page_size) ) return await self.session.scalars(stmt) async def get_page_count(self, page_size=50): stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel) item_count = await self.session.scalar(stmt) if item_count is None: item_count = 0 return int(math.ceil(item_count / page_size)) async def find_by_id(self, id: int): stmt = select(BotPersonaModel).where(BotPersonaModel.id == id) return await self.session.scalar(stmt) async def find_by_bot_id(self, bot_id: str): stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id) return await self.session.scalar(stmt) async def get_system_prompt(self, bot_id: str) -> str | None: stmt = select(BotPersonaModel.system_prompt).where( BotPersonaModel.bot_id == bot_id ) return await self.session.scalar(stmt) @staticmethod async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None: async with BotPersonaHelper(dbs) as bot_persona_helper: return await bot_persona_helper.get_system_prompt(bot_id)