|
|
@ -1,10 +1,15 @@
|
|
|
|
from __future__ import annotations
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
|
|
|
from typing import Optional
|
|
|
|
import sqlalchemy
|
|
|
|
import sqlalchemy
|
|
|
|
from api.model.base import BaseHelper, BaseModel
|
|
|
|
from api.model.base import BaseHelper, BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
import sqlalchemy
|
|
|
|
import sqlalchemy
|
|
|
|
from sqlalchemy import select, update
|
|
|
|
from sqlalchemy import select, update
|
|
|
|
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
|
|
|
from sqlalchemy.orm import mapped_column, relationship, load_only, Mapped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from api.model.chat_complete.bot_persona_category import BotPersonaCategoryModel
|
|
|
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
|
|
|
|
|
|
|
class BotPersonaModel(BaseModel):
|
|
|
|
class BotPersonaModel(BaseModel):
|
|
|
|
__tablename__ = "chat_complete_bot_persona"
|
|
|
|
__tablename__ = "chat_complete_bot_persona"
|
|
|
@ -14,9 +19,12 @@ class BotPersonaModel(BaseModel):
|
|
|
|
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
|
|
|
bot_name: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)
|
|
|
|
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
bot_avatar: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
bot_description: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
|
|
|
|
category_id: Mapped[int] = mapped_column(
|
|
|
|
|
|
|
|
sqlalchemy.ForeignKey(BotPersonaCategoryModel.id, ondelete="CASCADE", onupdate="CASCADE"), index=True)
|
|
|
|
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
|
|
|
|
system_prompt: Mapped[str] = mapped_column(sqlalchemy.String)
|
|
|
|
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
|
|
|
|
message_log: Mapped[list] = mapped_column(sqlalchemy.JSON)
|
|
|
|
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
default_question: Mapped[str] = mapped_column(sqlalchemy.String, nullable=True)
|
|
|
|
|
|
|
|
updated_at: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
|
|
|
|
|
|
|
|
|
|
|
class BotPersonaHelper(BaseHelper):
|
|
|
|
class BotPersonaHelper(BaseHelper):
|
|
|
|
async def add(self, obj: BotPersonaModel):
|
|
|
|
async def add(self, obj: BotPersonaModel):
|
|
|
@ -30,15 +38,32 @@ class BotPersonaHelper(BaseHelper):
|
|
|
|
await self.session.commit()
|
|
|
|
await self.session.commit()
|
|
|
|
return obj
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
|
|
async def get_list(self):
|
|
|
|
async def get_list(self, page: Optional[int] = 1, page_size: Optional[int] = 20, category_id: Optional[int] = None):
|
|
|
|
stmt = select(BotPersonaModel).with_only_columns([
|
|
|
|
offset_index = (page - 1) * page_size
|
|
|
|
BotPersonaModel.id,
|
|
|
|
|
|
|
|
BotPersonaModel.bot_name,
|
|
|
|
stmt = select(BotPersonaModel) \
|
|
|
|
BotPersonaModel.bot_avatar,
|
|
|
|
.options(load_only("id", "bot_id", "bot_name", "bot_avatar", "bot_description", "updated_at")) \
|
|
|
|
BotPersonaModel.bot_description
|
|
|
|
.order_by(BotPersonaModel.updated_at.desc()) \
|
|
|
|
])
|
|
|
|
.offset(offset_index) \
|
|
|
|
|
|
|
|
.limit(page_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if category_id is not None:
|
|
|
|
|
|
|
|
stmt = stmt.where(BotPersonaModel.category_id == category_id)
|
|
|
|
|
|
|
|
|
|
|
|
return await self.session.scalars(stmt)
|
|
|
|
return await self.session.scalars(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_page_count(self, page_size = 50, category_id: Optional[int] = None):
|
|
|
|
|
|
|
|
stmt = select(sqlalchemy.func.count()).select_from(BotPersonaModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if category_id is not None:
|
|
|
|
|
|
|
|
stmt = stmt.where(BotPersonaModel.category_id == category_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
item_count = await self.session.scalar(stmt)
|
|
|
|
|
|
|
|
if item_count is None:
|
|
|
|
|
|
|
|
item_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return int(math.ceil(item_count / page_size))
|
|
|
|
|
|
|
|
|
|
|
|
async def find_by_id(self, id: int):
|
|
|
|
async def find_by_id(self, id: int):
|
|
|
|
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
|
|
|
|
stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
@ -46,3 +71,12 @@ class BotPersonaHelper(BaseHelper):
|
|
|
|
async def find_by_bot_id(self, bot_id: str):
|
|
|
|
async def find_by_bot_id(self, bot_id: str):
|
|
|
|
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
|
|
|
|
stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_system_prompt(self, bot_id: str) -> str | None:
|
|
|
|
|
|
|
|
stmt = select(BotPersonaModel.system_prompt).where(BotPersonaModel.bot_id == bot_id)
|
|
|
|
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
async def get_cached_system_prompt(dbs: DatabaseService, bot_id: str) -> str | None:
|
|
|
|
|
|
|
|
async with BotPersonaHelper(dbs) as bot_persona_helper:
|
|
|
|
|
|
|
|
return await bot_persona_helper.get_system_prompt(bot_id)
|