You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
2.1 KiB
Python

from typing import Optional, Union
import sqlalchemy
from sqlalchemy import select, update, delete
from sqlalchemy.orm import mapped_column, Mapped
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
class TitleCollectionModel(BaseModel):
    """ORM mapping for the ``embedding_search_title_collection`` table.

    Each row holds a title and, optionally, a page id associated with it.
    """

    __tablename__ = "embedding_search_title_collection"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer, primary_key=True, autoincrement=True
    )
    # Title text, limited to 255 characters; indexed (the index is not unique).
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    # Indexed page id. The column is nullable, so the Python-side type is
    # Optional to agree with ``nullable=True`` (rows may carry no page id).
    page_id: Mapped[Optional[int]] = mapped_column(
        sqlalchemy.Integer, index=True, nullable=True
    )
class TitleCollectionHelper(BaseHelper):
    """Async create/update/delete/lookup helpers for ``TitleCollectionModel``.

    All methods operate through ``self.session``, which is not defined in this
    class (presumably an async SQLAlchemy session provided by ``BaseHelper``
    -- confirm against the base class).
    """

    async def add(
        self, title: str, page_id: Optional[int] = None
    ) -> TitleCollectionModel | None:
        """Insert a row for *title* unless one already exists.

        Args:
            title: Title to store.
            page_id: Page id to store with the title, or ``None`` for no page id.

        Returns:
            The newly inserted, refreshed model instance, or ``None`` when a
            row with the same title was already present (nothing is written).
        """
        # Only the id is selected: this is purely an existence check.
        stmt = select(TitleCollectionModel.id).where(
            TitleCollectionModel.title == title
        )
        existing_id = await self.session.scalar(stmt)
        if existing_id is not None:
            return None

        # NOTE(review): the check above and the insert below are not atomic,
        # and the title column is indexed but not unique, so concurrent calls
        # could still insert duplicate titles.
        obj = TitleCollectionModel(title=title, page_id=page_id)
        self.session.add(obj)
        await self.session.commit()
        # Reload the instance after the commit so generated values (e.g. the
        # autoincrement id) are populated on the returned object.
        await self.session.refresh(obj)
        return obj

    async def set_main_page_id(
        self, title: str, page_id: Optional[int] = None
    ) -> None:
        """Set ``page_id`` on every row whose title equals *title*, then commit.

        Args:
            title: Title of the row(s) to update.
            page_id: New page id; ``None`` stores NULL in the column.
        """
        stmt = (
            update(TitleCollectionModel)
            .where(TitleCollectionModel.title == title)
            .values(page_id=page_id)
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove(self, title: str) -> None:
        """Delete every row whose title equals *title*, then commit.

        Args:
            title: Title of the row(s) to delete.
        """
        stmt = delete(TitleCollectionModel).where(
            TitleCollectionModel.title == title
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def find_by_id(self, id: int) -> TitleCollectionModel | None:
        """Return the row with primary key *id*, or ``None`` if there is none.

        Args:
            id: Primary-key value to look up.
        """
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.id == id)
        return await self.session.scalar(stmt)

    async def find_by_title(self, title: str) -> TitleCollectionModel | None:
        """Return the first row whose title equals *title*, or ``None``.

        Args:
            title: Exact title to look up.
        """
        stmt = select(TitleCollectionModel).where(
            TitleCollectionModel.title == title
        )
        return await self.session.scalar(stmt)

    async def find_by_page_id(self, page_id: int) -> TitleCollectionModel | None:
        """Return the first row whose ``page_id`` equals *page_id*, or ``None``.

        Args:
            page_id: Page id to look up.
        """
        stmt = select(TitleCollectionModel).where(
            TitleCollectionModel.page_id == page_id
        )
        return await self.session.scalar(stmt)