You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
3.8 KiB
Python

from __future__ import annotations
import math
from typing import Optional
import sqlalchemy
from api.model.base import BaseHelper, BaseModel
import sqlalchemy
from sqlalchemy import select, update
from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
from service.database import DatabaseService
class BotPersonaModel(BaseModel):
    """ORM mapping for rows of the ``chat_complete_bot_persona`` table.

    Columns declared with ``nullable=True`` are annotated ``Optional[...]`` so
    the Python-side type agrees with what the database is allowed to return.
    The explicit ``nullable=True`` is kept, so the generated schema is the
    same as with a plain ``Mapped[str]`` annotation.
    """

    __tablename__ = "chat_complete_bot_persona"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer, primary_key=True, autoincrement=True
    )

    # Indexed string identifier, at most 60 characters.
    bot_id: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
    # Indexed name, at most 60 characters.
    bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)

    # Nullable free-form strings; may be None when loaded from the database.
    # NOTE(review): bot_avatar presumably holds a URL or path -- confirm with callers.
    bot_avatar: Mapped[Optional[str]] = mapped_column(
        sqlalchemy.String, nullable=True
    )
    bot_description: Mapped[Optional[str]] = mapped_column(
        sqlalchemy.String, nullable=True
    )

    # Indexed foreign key to BotPersonaCategoryModel.id; deletes and updates
    # of the referenced category cascade to this row.
    category_id: Mapped[int] = mapped_column(
        sqlalchemy.ForeignKey(
            BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"
        ),
        index=True,
    )

    system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
    # Stored as JSON and annotated as a list; the element schema is not
    # defined in this file.
    message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
    default_question: Mapped[Optional[str]] = mapped_column(
        sqlalchemy.String, nullable=True
    )

    # Indexed integer used to sort listings (newest first).
    # NOTE(review): presumably a Unix timestamp -- confirm the unit with writers.
    updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
class BotPersonaHelper(BaseHelper):
    """Async query helper for :class:`BotPersonaModel` rows.

    All methods operate on ``self.session``, which is not defined in this
    class (presumably supplied by ``BaseHelper`` -- confirm there).
    """

    async def add(self, obj: BotPersonaModel):
        """Add ``obj`` to the session, commit, refresh it, and return it."""
        self.session.add(obj)
        await self.session.commit()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: BotPersonaModel):
        """Merge ``obj`` into the session and commit.

        Returns:
            The instance returned by ``session.merge``, which is not
            necessarily the same object that was passed in.
        """
        obj = await self.session.merge(obj)
        await self.session.commit()
        return obj

    async def get_list(
        self,
        page: Optional[int] = 1,
        page_size: Optional[int] = 20,
        category_id: Optional[int] = None,
    ):
        """Return one page of personas ordered by ``updated_at`` descending.

        Only the summary columns (id, bot_id, bot_name, bot_avatar,
        bot_description, updated_at) are requested via ``load_only``.

        Args:
            page: 1-based page number. ``None`` or any value below 1 is
                treated as the first page.
            page_size: Maximum rows per page. ``None`` falls back to 20.
            category_id: When not ``None``, restrict results to this category.

        Returns:
            The result of ``session.scalars`` for the page query.
        """
        # The signature allows None, so normalise it instead of failing with
        # a TypeError in the offset arithmetic below.
        page = 1 if page is None else max(page, 1)
        page_size = 20 if page_size is None else page_size
        offset_index = (page - 1) * page_size

        stmt = select(BotPersonaModel).options(
            load_only(
                BotPersonaModel.id,
                BotPersonaModel.bot_id,
                BotPersonaModel.bot_name,
                BotPersonaModel.bot_avatar,
                BotPersonaModel.bot_description,
                BotPersonaModel.updated_at,
            )
        )
        if category_id is not None:
            stmt = stmt.where(BotPersonaModel.category_id == category_id)

        # The primary key is a tie-breaker: without it, rows sharing the same
        # updated_at have no defined order, so OFFSET/LIMIT paging could
        # repeat or skip rows between pages.
        stmt = (
            stmt.order_by(
                BotPersonaModel.updated_at.desc(),
                BotPersonaModel.id.desc(),
            )
            .offset(offset_index)
            .limit(page_size)
        )
        return await self.session.scalars(stmt)

    async def get_page_count(self, page_size=50, category_id: Optional[int] = None):
        """Return how many pages of ``page_size`` rows exist.

        Args:
            page_size: Rows per page. A value of 0 raises ZeroDivisionError.
            category_id: When not ``None``, count only rows in this category.

        NOTE(review): the default here (50) differs from ``get_list``'s
        default (20); callers relying on defaults for both will get a page
        count that does not match the listing. Pass ``page_size`` explicitly.
        """
        stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
        if category_id is not None:
            stmt = stmt.where(BotPersonaModel.category_id == category_id)

        item_count = await self.session.scalar(stmt)
        if item_count is None:
            item_count = 0
        return int(math.ceil(item_count / page_size))

    async def find_by_id(self, id: int):
        """Return the persona whose primary key equals ``id``, or ``None``."""
        stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
        return await self.session.scalar(stmt)

    async def find_by_bot_id(self, bot_id: str):
        """Return the first persona whose ``bot_id`` matches, or ``None``."""
        stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
        return await self.session.scalar(stmt)

    async def get_system_prompt(self, bot_id: str) -> str | None:
        """Return only the ``system_prompt`` column for ``bot_id``, or ``None``."""
        stmt = select(BotPersonaModel.system_prompt).where(
            BotPersonaModel.bot_id == bot_id
        )
        return await self.session.scalar(stmt)

    @staticmethod
    async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None:
        """Open a temporary helper on ``dbs`` and fetch the system prompt.

        NOTE(review): despite the name, no caching is performed in this
        method; every call delegates to ``get_system_prompt``. Any caching
        would have to be applied outside this function -- confirm.
        """
        async with BotPersonaHelper(dbs) as bot_persona_helper:
            return await bot_persona_helper.get_system_prompt(bot_id)