You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
from typing import Optional, Union
|
|
import sqlalchemy
|
|
from sqlalchemy import select, update, delete
|
|
from sqlalchemy.orm import mapped_column, Mapped
|
|
|
|
from api.model.base import BaseModel
|
|
from service.database import DatabaseService
|
|
|
|
class TitleCollectionModel(BaseModel):
    """ORM row mapping a page title to an (optional) page id.

    Stored in the ``embedding_search_title_collection`` table.
    """

    __tablename__ = "embedding_search_title_collection"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Looked up by exact match; indexed but NOT unique at the schema level.
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    # Nullable on purpose: TitleCollectionHelper.add() and set_page_id() both
    # accept page_id=None. A bare Mapped[int] would make SQLAlchemy 2.0 emit
    # NOT NULL and reject those inserts/updates.
    # NOTE(review): existing tables created with NOT NULL need a migration.
    page_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True, index=True)
|
|
|
|
class TitleCollectionHelper:
    """Async CRUD helper for ``TitleCollectionModel`` rows.

    Use as an async context manager. Entering opens a session from
    ``dbs.create_session()`` and exiting closes it. The query methods all use
    that session, so call them inside the ``async with`` block.
    """

    def __init__(self, dbs: DatabaseService):
        """Remember the database service; the session is opened in ``__aenter__``.

        :param dbs: service whose ``create_session()`` returns an async
            context-manager session.
        """
        self.dbs = dbs
        # True while this helper holds an open session.
        self.initialized = False

    async def __aenter__(self) -> "TitleCollectionHelper":
        """Open a session if one is not already open, and return ``self``."""
        if not self.initialized:
            # Exposed as an attribute for callers that want to open their own sessions.
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
            self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the session opened by ``__aenter__``; exceptions still propagate."""
        if not self.initialized:
            # Already closed (e.g. nested exit); nothing to release.
            return
        try:
            await self.session.__aexit__(exc_type, exc, tb)
        finally:
            # Reset so a later ``async with`` opens a fresh session instead of
            # silently reusing the one that was just closed.
            self.initialized = False

    async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]:
        """Insert a row for *title* unless a row with that title already exists.

        :param title: title to store.
        :param page_id: optional page id to associate with the title.
        :returns: the new row's ``id``, or ``False`` if *title* was already present.
        """
        # NOTE(review): check-then-insert is not atomic and ``title`` has no
        # unique constraint, so concurrent callers can still insert duplicates.
        existing_id = await self.session.scalar(
            select(TitleCollectionModel.id).where(TitleCollectionModel.title == title)
        )
        if existing_id is not None:
            return False

        obj = TitleCollectionModel(title=title, page_id=page_id)
        self.session.add(obj)
        await self.session.commit()
        # Reload the row so the database-generated id is available on ``obj``
        # (commit may expire attributes, depending on the session config).
        await self.session.refresh(obj)
        return obj.id

    async def set_page_id(self, title: str, page_id: Optional[int] = None) -> None:
        """Set ``page_id`` on every row whose title equals *title*, then commit."""
        stmt = (
            update(TitleCollectionModel)
            .where(TitleCollectionModel.title == title)
            .values(page_id=page_id)
        )
        await self.session.execute(stmt)
        await self.session.commit()

    async def remove(self, title: str) -> None:
        """Delete every row whose title equals *title*, then commit."""
        stmt = delete(TitleCollectionModel).where(TitleCollectionModel.title == title)
        await self.session.execute(stmt)
        await self.session.commit()

    async def find_by_title(self, title: str) -> Optional["TitleCollectionModel"]:
        """Return the first row whose title equals *title*, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
        return await self.session.scalar(stmt)

    async def find_by_page_id(self, page_id: int) -> Optional["TitleCollectionModel"]:
        """Return the first row whose ``page_id`` equals *page_id*, or ``None``."""
        stmt = select(TitleCollectionModel).where(TitleCollectionModel.page_id == page_id)
        return await self.session.scalar(stmt)
|