You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.9 KiB
Python
48 lines
1.9 KiB
Python
2 years ago
|
from __future__ import annotations
|
||
|
import sqlalchemy
|
||
|
from api.model.base import BaseHelper, BaseModel
|
||
|
|
||
|
import sqlalchemy
|
||
|
from sqlalchemy import select, update
|
||
|
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||
|
|
||
|
class BotPersonaModel(BaseModel):
    """ORM mapping for one row of the ``chat_complete_bot_persona`` table."""

    __tablename__ = "chat_complete_bot_persona"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer,
        primary_key=True,
        autoincrement=True,
    )

    # Indexed string columns, each limited to 60 characters.
    bot_id: Mapped[str] = mapped_column(
        sqlalchemy.String(60),
        index=True,
    )
    bot_name: Mapped[str] = mapped_column(
        sqlalchemy.String(60),
        index=True,
    )

    # Nullable string columns.
    bot_avatar: Mapped[str] = mapped_column(
        sqlalchemy.String,
        nullable=True,
    )
    bot_description: Mapped[str] = mapped_column(
        sqlalchemy.String,
        nullable=True,
    )

    # String column declared without an explicit nullable flag.
    system_prompt: Mapped[str] = mapped_column(
        sqlalchemy.String,
    )

    # JSON column, annotated as a list on the Python side.
    message_log: Mapped[list] = mapped_column(
        sqlalchemy.JSON,
    )

    # Nullable string column.
    default_question: Mapped[str] = mapped_column(
        sqlalchemy.String,
        nullable=True,
    )
|
||
|
|
||
|
class BotPersonaHelper(BaseHelper):
    """Async persistence helper for :class:`BotPersonaModel` rows.

    All methods operate on ``self.session``, which is provided by
    ``BaseHelper`` (presumably a SQLAlchemy ``AsyncSession`` -- confirm
    against the base class).
    """

    async def add(self, obj: BotPersonaModel) -> BotPersonaModel:
        """Add ``obj`` to the session, commit, refresh it and return it."""
        self.session.add(obj)
        await self.session.commit()
        # Reload the instance after the commit so it reflects the stored row.
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: BotPersonaModel) -> BotPersonaModel:
        """Merge ``obj`` into the session, commit, and return the merged instance.

        The returned object is the one produced by ``session.merge`` and
        may be a different instance from the one passed in.
        """
        merged = await self.session.merge(obj)
        await self.session.commit()
        return merged

    async def get_list(self):
        """Return the summary columns of every persona.

        Only ``id``, ``bot_name``, ``bot_avatar`` and ``bot_description``
        are selected. The returned result yields one row per persona, with
        the selected columns available as attributes of the same name
        (``row.id``, ``row.bot_name``, ...).
        """
        # Columns are passed positionally: SQLAlchemy 2.0 rejects the legacy
        # list form with ArgumentError.
        stmt = select(
            BotPersonaModel.id,
            BotPersonaModel.bot_name,
            BotPersonaModel.bot_avatar,
            BotPersonaModel.bot_description,
        )
        # execute() keeps every selected column; scalars() would yield only
        # the first column (id) of each row.
        return await self.session.execute(stmt)

    async def find_by_id(self, id: int):
        """Return the persona whose primary key equals ``id``, or ``None``."""
        stmt = select(BotPersonaModel).where(BotPersonaModel.id == id)
        return await self.session.scalar(stmt)

    async def find_by_bot_id(self, bot_id: str):
        """Return the first persona whose ``bot_id`` matches, or ``None``."""
        stmt = select(BotPersonaModel).where(BotPersonaModel.bot_id == bot_id)
        return await self.session.scalar(stmt)
|