import hashlib

import asyncpg
import numpy as np
import sqlalchemy
from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import mapped_column, relationship, Mapped

import config
from api.model.base import BaseModel
from service.database import DatabaseService


class TitleIndexModel(BaseModel):
    """ORM mapping for the ``embedding_search_title_index`` table."""

    __tablename__ = "embedding_search_title_index"

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Hex SHA-1 digest of the UTF-8 encoded title; TitleIndexHelper uses it
    # as the lookup key for a title.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # Vector whose dimension is fixed by config.EMBEDDING_VECTOR_SIZE.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))

    # IVFFlat index built with the cosine operator class.
    embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
                                       postgresql_using='ivfflat',
                                       postgresql_ops={'embedding': 'vector_cosine_ops'})


class TitleIndexHelper:
    """Raw-SQL access to the title embedding index over a pooled connection.

    Use it as an async context manager: entering acquires a connection from
    ``dbs.pool`` and registers the pgvector codec on it; leaving releases the
    connection. All query methods require the helper to be entered first.
    """

    __tablename__ = "embedding_search_title_index"

    columns = [
        "id",
        "sha1",
        "title",
        "page_id",
        "collection_id",
        "rev_id",
        "embedding",
    ]

    def __init__(self, dbs: DatabaseService):
        """Store the database service; no connection is acquired yet."""
        self.dbs = dbs
        self.initialized = False

    async def __aenter__(self):
        """Acquire a connection and register the vector type on it.

        Returns:
            The helper itself.
        """
        if not self.initialized:
            self.dbpool = self.dbs.pool.acquire()
            self.dbi = await self.dbpool.__aenter__()
            try:
                await register_vector(self.dbi)
            except BaseException as err:
                # ``async with`` does not call __aexit__ when __aenter__
                # fails, so release the connection here to avoid a leak.
                await self.dbpool.__aexit__(type(err), err, err.__traceback__)
                raise
            self.initialized = True

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the connection acquired by ``__aenter__``."""
        if not self.initialized:
            # Nothing was acquired, so there is nothing to release.
            return
        # Reset before releasing so the next ``async with`` acquires a fresh
        # connection instead of reusing one already returned to the pool.
        self.initialized = False
        await self.dbpool.__aexit__(exc_type, exc, tb)

    @staticmethod
    def _title_sha1(title: str) -> str:
        """Return the hex SHA-1 digest of the UTF-8 encoded title."""
        return hashlib.sha1(title.encode("utf-8")).hexdigest()

    def _select_columns(self, with_embedding: bool) -> str:
        """Return the column list for a SELECT, optionally without the vector."""
        if with_embedding:
            return self.get_columns()
        return self.get_columns(exclude=["embedding"])

    def get_columns(self, exclude=()):
        """Return the table columns as a comma-separated SQL fragment.

        Args:
            exclude: Column names to leave out. Empty (the default) keeps
                every column.
        """
        if not exclude:
            return ", ".join(self.columns)

        return ", ".join(col for col in self.columns if col not in exclude)

    async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray):
        """Add a title to the index.

        Returns:
            The id of the new row, or ``False`` when a row with the same
            title hash already exists (nothing is inserted in that case).
        """
        title_sha1 = self._title_sha1(title)

        # Only the primary key is needed to test for existence; selecting
        # every column would also transfer the embedding vector.
        existing_id = await self.dbi.fetchval(
            "SELECT id FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
        if existing_id is not None:
            return False

        new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index
                                         (sha1, title, page_id, rev_id, collection_id, embedding)
                                         VALUES ($1, $2, $3, $4, $5, $6)
                                         RETURNING id""",
                                         title_sha1, title, page_id, rev_id, collection_id, embedding, column=0)
        return new_id

    async def remove(self, title: str):
        """Remove a title from the index."""
        title_sha1 = self._title_sha1(title)
        await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)

    async def update_rev_id(self, page_id: int, rev_id: int):
        """Update the indexed revision id of the rows matching ``page_id``."""
        await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2",
                               rev_id, page_id)

    async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int,
                                embedding: np.ndarray):
        """Update title, revision, collection and embedding for ``page_id``.

        NOTE(review): the ``sha1`` column is not updated here even though the
        title may change; confirm whether callers rely on that.
        """
        # The previous body read an undefined ``collection_page_id`` local
        # before assigning it, which raised UnboundLocalError on every call.
        await self.dbi.execute("""UPDATE embedding_search_title_index
                               SET title = $1, rev_id = $2, collection_id = $3, embedding = $4
                               WHERE page_id = $5""",
                               title, rev_id, collection_id, embedding, page_id)

    async def search_title_embedding(self, embedding: np.ndarray, limit: int = 10):
        """Search for titles by cosine similarity.

        Args:
            embedding: Query vector.
            limit: Maximum number of rows to return.

        Returns:
            Rows holding every column except ``embedding`` plus a
            ``distance`` column, ordered from most to least similar.
        """
        columns = self.get_columns(exclude=["embedding"])
        # ``<=>`` is pgvector's cosine distance, matching the index operator
        # class on the model; ascending order puts the closest titles first.
        query = (
            f"SELECT {columns}, embedding <=> $1 AS distance "
            "FROM embedding_search_title_index "
            "ORDER BY distance ASC "
            "LIMIT $2"
        )
        # int() keeps the integer coercion the old "%d" formatting applied.
        return await self.dbi.fetch(query, embedding, int(limit))

    async def find_by_title(self, title: str, with_embedding=False):
        """Find a title in the index.

        Returns:
            The first matching row, or ``None`` when the title is not indexed.
        """
        title_sha1 = self._title_sha1(title)
        columns = self._select_columns(with_embedding)

        return await self.dbi.fetchrow(
            "SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns,
            title_sha1
        )

    async def find_by_page_id(self, page_id: int, with_embedding=False):
        """Find the first row for ``page_id``, or ``None`` if there is none."""
        columns = self._select_columns(with_embedding)

        return await self.dbi.fetchrow(
            "SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns,
            page_id
        )

    async def find_by_collection_id(self, collection_id: int, with_embedding=False):
        """Return every row that belongs to ``collection_id``."""
        columns = self._select_columns(with_embedding)

        return await self.dbi.fetch(
            "SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns,
            collection_id
        )