You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
3.8 KiB
Python
98 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import List, Optional
|
|
import sqlalchemy
|
|
from sqlalchemy import update
|
|
from sqlalchemy.orm import mapped_column, Mapped
|
|
|
|
from api.model.base import BaseModel
|
|
from service.database import DatabaseService
|
|
|
|
|
|
class ConversationModel(BaseModel):
    """ORM mapping for a row of the ``toolkit_ui_conversation`` table."""

    __tablename__ = "toolkit_ui_conversation"

    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer, primary_key=True, autoincrement=True)
    user_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    module: Mapped[str] = mapped_column(sqlalchemy.String(60), index=True)

    # Columns declared nullable=True are annotated Optional so the type hints
    # agree with the schema; nullability itself is still set explicitly.
    title: Mapped[Optional[str]] = mapped_column(
        sqlalchemy.String(255), nullable=True)
    thumbnail: Mapped[Optional[str]] = mapped_column(
        sqlalchemy.Text(), nullable=True)
    page_id: Mapped[Optional[int]] = mapped_column(
        sqlalchemy.Integer, index=True, nullable=True)
    rev_id: Mapped[Optional[int]] = mapped_column(
        sqlalchemy.Integer, nullable=True)

    # NOTE(review): the column type is TIMESTAMP while the annotation says
    # int; the annotation is left untouched here because correcting it needs
    # a datetime import at file level.
    updated_at: Mapped[int] = mapped_column(
        sqlalchemy.TIMESTAMP, index=True, server_default=sqlalchemy.func.now())
    pinned: Mapped[bool] = mapped_column(
        sqlalchemy.Boolean, default=False, index=True)

    # A callable default produces a fresh dict for every INSERT instead of
    # handing the same module-level {} instance to every row.
    extra: Mapped[dict] = mapped_column(sqlalchemy.JSON, default=dict)
|
|
|
|
|
|
class ConversationHelper:
    """Async helper for creating, querying, updating and deleting
    ``ConversationModel`` rows.

    Intended to be used as an async context manager: entering obtains a
    session from ``dbs.create_session()`` and enters it; exiting delegates to
    that session's ``__aexit__``.
    """

    def __init__(self, dbs: DatabaseService):
        """Store the database service; no session is created until entry."""
        self.dbs = dbs
        self.initialized = False

    async def __aenter__(self):
        """Create and enter the shared session on first entry.

        Returns:
            This helper instance.
        """
        if not self.initialized:
            # Still exposed as an attribute for code that reads it directly.
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()

            self.initialized = True

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Forward the exit (and any exception info) to the shared session."""
        await self.session.__aexit__(exc_type, exc, tb)

    async def add(self, user_id: int, module: str, title: Optional[str] = None, page_id: Optional[int] = None, rev_id: Optional[int] = None, extra: Optional[dict] = None):
        """Insert a new conversation and return the refreshed instance.

        Args:
            user_id: Value for the ``user_id`` column.
            module: Value for the ``module`` column.
            title: Optional value for the ``title`` column.
            page_id: Optional value for the ``page_id`` column.
            rev_id: Optional value for the ``rev_id`` column.
            extra: Optional value for the ``extra`` column; when ``None`` the
                attribute is not assigned and the column default applies.

        Returns:
            The newly added ``ConversationModel`` after commit and refresh.
        """
        obj = ConversationModel(
            user_id=user_id,
            module=module,
            title=title,
            page_id=page_id,
            rev_id=rev_id,
            updated_at=sqlalchemy.func.current_timestamp(),
        )

        if extra is not None:
            obj.extra = extra

        self.session.add(obj)
        await self.session.commit()
        await self.session.refresh(obj)
        return obj

    async def refresh_updated_at(self, conversation_id: int):
        """Set ``updated_at`` to the current timestamp for one conversation."""
        stmt = (
            update(ConversationModel)
            .where(ConversationModel.id == conversation_id)
            .values(updated_at=sqlalchemy.func.current_timestamp())
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def update(self, obj: ConversationModel):
        """Merge ``obj`` into the shared session, commit, and return the
        refreshed session-bound instance.

        ``merge`` returns the instance that actually belongs to the session.
        For an object loaded elsewhere (for example via ``find_by_id``, which
        uses its own session) that is a different object from ``obj``, so the
        refresh must target the merged instance rather than ``obj`` itself.
        When ``obj`` already belongs to the session, the merged instance is
        ``obj`` and the result is unchanged.

        Returns:
            The merged, refreshed ``ConversationModel``.
        """
        merged = await self.session.merge(obj)
        await self.session.commit()
        await self.session.refresh(merged)
        return merged

    async def get_conversation_list(self, user_id: int, module: Optional[str] = None, page_id: Optional[int] = None) -> List[ConversationModel]:
        """Select a user's conversations, pinned first, then newest first.

        Args:
            user_id: Only conversations with this ``user_id`` are selected.
            module: When given, additionally filter on ``module``.
            page_id: When given, additionally filter on ``page_id``.

        Returns:
            The result of ``session.scalars(stmt)``.

        NOTE(review): the annotation promises a list, but the value returned
        is whatever ``session.scalars`` yields; it is passed through
        unchanged so existing callers keep working.
        """
        stmt = sqlalchemy.select(ConversationModel).where(
            ConversationModel.user_id == user_id)

        if module is not None:
            stmt = stmt.where(ConversationModel.module == module)

        if page_id is not None:
            stmt = stmt.where(ConversationModel.page_id == page_id)

        stmt = stmt.order_by(
            ConversationModel.pinned.desc(),
            ConversationModel.updated_at.desc(),
        )

        return await self.session.scalars(stmt)

    async def find_by_id(self, conversation_id: int):
        """Fetch one conversation by primary key using a dedicated session.

        The lookup runs in its own session obtained from the database
        service, not in ``self.session``, so it also works when the helper
        has not been entered. That session has been exited by the time the
        caller receives the result.

        Returns:
            The result of ``session.scalar(stmt)`` for the id filter.
        """
        async with self.dbs.create_session() as session:
            stmt = sqlalchemy.select(ConversationModel).where(
                ConversationModel.id == conversation_id)
            return await session.scalar(stmt)

    async def remove(self, conversation_id: int):
        """Delete the conversation with the given id and commit."""
        stmt = sqlalchemy.delete(ConversationModel).where(
            ConversationModel.id == conversation_id)
        await self.session.execute(stmt)
        await self.session.commit()