You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
2.1 KiB
Python
51 lines
2.1 KiB
Python
from typing import Optional, Union
|
|
import sqlalchemy
|
|
from sqlalchemy import select, update, delete
|
|
from sqlalchemy.orm import mapped_column, Mapped
|
|
|
|
from api.model.base import BaseHelper, BaseModel
|
|
from service.database import DatabaseService
|
|
|
|
class TitleCollectionModel(BaseModel):
    """ORM model for rows of the ``embedding_search_title_collection`` table.

    Each row stores a title and, optionally, a page id associated with it.
    """

    __tablename__ = "embedding_search_title_collection"

    # Auto-incrementing integer primary key.
    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Title text, limited to 255 characters; indexed, but no unique
    # constraint is declared on it here.
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    # The column is nullable, so the Python-side type is Optional[int];
    # nullable=True stays explicit so the generated schema is unchanged.
    page_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, index=True, nullable=True)
|
|
|
|
class TitleCollectionHelper(BaseHelper):
    """Async CRUD helpers for ``TitleCollectionModel`` rows.

    All methods go through ``self.session``, which is awaited throughout.
    NOTE(review): ``self.session`` is presumably an SQLAlchemy async session
    set up by ``BaseHelper`` -- confirm against ``api.model.base``.
    """

    async def add(self, title: str, page_id: Optional[int] = None) -> TitleCollectionModel | None:
        """Insert a row for ``title`` unless one already exists.

        Args:
            title: Title to store.
            page_id: Optional page id to store alongside the title.

        Returns:
            The newly created (committed and refreshed) row, or ``None`` when
            a row with the same title was already present.
        """
        # NOTE(review): the existence check and the insert are separate
        # statements, and the title column declares no unique constraint, so
        # concurrent callers could presumably both insert the same title.
        stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)
        existing_id = await self.session.scalar(stmt)
        if existing_id is not None:
            return None

        obj = TitleCollectionModel(title=title, page_id=page_id)
        self.session.add(obj)
        await self.session.commit()
        # Reload the instance so database-generated values (e.g. id) are set.
        await self.session.refresh(obj)
        return obj

    async def set_main_page_id(self, title: str, page_id: Optional[int] = None) -> None:
        """Set ``page_id`` on every row whose title equals ``title``.

        Args:
            title: Title of the row(s) to update.
            page_id: New page id; ``None`` (the default) stores NULL.
        """
        stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id)
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove(self, title: str) -> None:
        """Delete every row whose title equals ``title`` and commit."""
        stmt = delete(TitleCollectionModel).where(TitleCollectionModel.title == title)
        await self.session.execute(stmt)
        await self.session.commit()

    async def find_by_id(self, id: int) -> TitleCollectionModel | None:
        """Return the row with primary key ``id``, or ``None`` if absent."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.id == id)
        return await self.session.scalar(stmt)

    async def find_by_title(self, title: str) -> TitleCollectionModel | None:
        """Return the first row whose title equals ``title``, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
        return await self.session.scalar(stmt)

    async def find_by_page_id(self, page_id: int) -> TitleCollectionModel | None:
        """Return the first row whose page id equals ``page_id``, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.page_id == page_id)
        return await self.session.scalar(stmt)
|