from __future__ import annotations import hashlib import asyncpg import numpy as np from pgvector.sqlalchemy import Vector from pgvector.asyncpg import register_vector import sqlalchemy from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.ext.asyncio import AsyncEngine import config from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService class TitleIndexModel(BaseModel): __tablename__ = "embedding_search_title_index" id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) sha1: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True)) embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, postgresql_using='ivfflat', postgresql_ops={'embedding': 'vector_cosine_ops'}) class TitleIndexHelper(BaseHelper): __tablename__ = "embedding_search_title_index" columns = [ "id", "sha1", "title", "page_id", "collection_id", "indexed_rev_id", "latest_rev_id" "embedding", ] def __init__(self, dbs: DatabaseService): super().__init__(dbs) async def __aenter__(self): if not self.initialized: self.dbpool = self.dbs.pool.acquire() self.dbi = await self.dbpool.__aenter__() await register_vector(self.dbi) return await super().__aenter__() async def __aexit__(self, exc_type, exc, tb): await self.dbpool.__aexit__(exc_type, exc, tb) return await super().__aexit__(exc_type, exc, tb) def get_columns(self, exclude=[]): if len(exclude) == 0: return ", ".join(self.columns) return ", ".join([col for col in self.columns if col not in exclude]) """ Add a title to the index """ async def add(self, obj: TitleIndexModel): obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest() if await self.find_by_sha1(obj.sha1) is None: self.session.add(obj) await self.session.commit() await self.session.refresh(obj) return obj return False """ Remove a title from the index """ async def remove(self, title: str): title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1) await self.session.execute(stmt) async def update(self, obj: TitleIndexModel): obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest() await self.session.merge(obj) await self.session.commit() await self.session.refresh(obj) return obj async def refresh(self, obj: TitleIndexModel): await self.session.refresh(obj) return obj """ Search for titles by consine similary """ async def search_title_embedding(self, embedding: np.ndarray, limit: int = 10): ret = self.dbi.fetch("""SELECT %s, embedding <-> $1 AS distance FROM embedding_search_title_index ORDER BY distance DESC LIMIT %d""" % (self.get_columns(exclude=['embedding']), limit), embedding) return ret """ Find a title in the index """ async def find_by_title(self, title: str): sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() return await self.find_by_sha1(sha1) async def find_by_sha1(self, sha1: str): stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1) return await self.session.scalar(stmt) async def find_by_page_id(self, page_id: int): stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id) return await self.session.scalar(stmt) async def find_list_by_collection_id(self, collection_id: int) -> sqlalchemy.ScalarResult[TitleIndexModel]: stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id) return await self.session.scalars(stmt) async def get_need_update_index_list(self, collection_id: int) -> list[TitleIndexModel]: page_list = await self.find_list_by_collection_id(collection_id) result: list[TitleIndexModel] = [] for page_info in page_list: if page_info.indexed_rev_id != page_info.latest_rev_id: result.append(page_info) return result