"""ORM model and async helper for the ``embedding_search_title_collection`` table."""
from typing import Optional, Union

import sqlalchemy
from sqlalchemy import select, update, delete
from sqlalchemy.orm import mapped_column, Mapped

from api.model.base import BaseModel
from service.database import DatabaseService


class TitleCollectionModel(BaseModel):
    """One row per page title, optionally linked to a page id."""

    __tablename__ = "embedding_search_title_collection"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Indexed for lookups by title; note there is no unique constraint.
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    # Optional[...] so the generated column is nullable: TitleCollectionHelper.add()
    # and set_page_id() both default page_id to None. A bare Mapped[int] would make
    # SQLAlchemy emit NOT NULL and those default calls would fail on insert/update.
    page_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, index=True)


class TitleCollectionHelper:
    """Async context manager exposing CRUD operations on ``TitleCollectionModel``.

    Each mutating method commits immediately on the helper's session.
    Use as ``async with TitleCollectionHelper(dbs) as helper: ...``.
    """

    def __init__(self, dbs: DatabaseService):
        """Store the database service; no session is opened until ``__aenter__``."""
        self.dbs = dbs
        self.initialized = False

    async def __aenter__(self):
        """Open and enter a session from ``dbs.create_session()`` on first entry.

        NOTE(review): ``initialized`` is never reset, so re-entering the helper
        after ``__aexit__`` reuses the same session object instead of creating
        a new one.
        """
        if not self.initialized:
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
            self.initialized = True

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Forward context exit (including any exception info) to the session."""
        await self.session.__aexit__(exc_type, exc, tb)

    async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
        """Insert ``title`` if no row with that title exists yet.

        Args:
            title: Page title to record.
            page_id: Optional page id to store alongside the title.

        Returns:
            The new row's ``id`` when a row was inserted, or ``False`` when a
            row with this title already existed.
        """
        # NOTE(review): the existence check and the insert are separate
        # statements and ``title`` has no unique constraint, so concurrent
        # callers could both insert the same title.
        stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)
        existing_id = await self.session.scalar(stmt)

        if existing_id is None:
            obj = TitleCollectionModel(title=title, page_id=page_id)
            self.session.add(obj)
            await self.session.commit()
            # Reload attributes after commit, presumably so that reading obj.id
            # below does not trigger an implicit lazy load on the async session.
            await self.session.refresh(obj)
            return obj.id

        return False

    async def set_page_id(self, title: str, page_id: Optional[int] = None) -> None:
        """Set ``page_id`` on every row whose title equals ``title``, then commit."""
        stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id)
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove(self, title: str) -> None:
        """Delete every row whose title equals ``title``, then commit."""
        stmt = delete(TitleCollectionModel).where(TitleCollectionModel.title == title)
        await self.session.execute(stmt)
        await self.session.commit()

    async def find_by_title(self, title: str) -> Optional[TitleCollectionModel]:
        """Return the first row whose title equals ``title``, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
        return await self.session.scalar(stmt)

    async def find_by_page_id(self, page_id: int) -> Optional[TitleCollectionModel]:
        """Return the first row with the given ``page_id``, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.page_id == page_id)
        return await self.session.scalar(stmt)