from __future__ import annotations

import hashlib
from typing import Optional, Type

import asyncpg
import numpy as np
import sqlalchemy
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector

import config
from api.model.base import BaseModel
from service.database import DatabaseService

# Cache of per-collection ORM models, keyed by collection id.
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}


class AbstractPageIndexModel(BaseModel):
    __abstract__ = True

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
    temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)


def create_page_index_model(collection_id: int):
    """Return (and cache) the ORM model bound to this collection's table."""
    if collection_id in page_index_model_list:
        return page_index_model_list[collection_id]
    else:
        class PageIndexModel(AbstractPageIndexModel):
            __tablename__ = "embedding_search_page_index_%s" % str(collection_id)

            embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx",
                                               AbstractPageIndexModel.embedding,
                                               postgresql_using='ivfflat',
                                               postgresql_ops={'embedding': 'vector_cosine_ops'})

        page_index_model_list[collection_id] = PageIndexModel
        return PageIndexModel


class PageIndexHelper:
    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id",
    ]

    def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
        self.dbs = dbs
        self.collection_id = collection_id
        self.page_id = page_id if page_id is not None else -1
        self.table_name = "embedding_search_page_index_%s" % str(collection_id)
        self.initialized = False
        self.table_initialized = False

    async def __aenter__(self):
        """Acquire a raw asyncpg connection and an ORM session, and resolve the per-collection model."""
        if self.initialized:
            return self

        self.dbpool = self.dbs.pool.acquire()
        self.dbi = await self.dbpool.__aenter__()
        await register_vector(self.dbi)

        self.create_session = self.dbs.create_session
        self.session = self.dbs.create_session()
        await self.session.__aenter__()

        self.orm = create_page_index_model(self.collection_id)

        self.initialized = True

        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.dbpool.__aexit__(exc_type, exc, tb)
        await self.session.__aexit__(exc_type, exc, tb)

    async def table_exists(self):
        exists = await self.dbi.fetchval("""SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = $1
        );""", self.table_name, column=0)
        return bool(exists)

    async def init_table(self):
        """Initialize the per-collection table, creating it if it does not exist."""
        if self.table_initialized:
            return

        # Create the table through the ORM metadata if it is missing.
        if not await self.table_exists():
            async with self.dbs.engine.begin() as conn:
                await conn.run_sync(self.orm.__table__.create)

        # Legacy raw-SQL table creation, kept for reference:
        # await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
        #     id SERIAL PRIMARY KEY,
        #     page_id INTEGER NOT NULL,
        #     sha1 VARCHAR(40) NOT NULL,
        #     text TEXT NOT NULL,
        #     text_len INTEGER NOT NULL,
        #     embedding VECTOR(%d) NOT NULL,
        #     markdown TEXT NULL,
        #     markdown_len INTEGER NULL,
        #     temp_doc_session_id INTEGER NULL
        # );
        # CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
        # CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
        # CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
        # """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))

        self.table_initialized = True

    async def create_embedding_index(self):
        pass
        # await self.dbi.execute("CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ USING ivfflat (embedding vector_cosine_ops);"
        #                        .replace("/*_*/", self.table_name))

    def sha1_doc(self, doc: list):
        """Fill in the sha1 of each document chunk that does not have one yet."""
        for item in doc:
            if "sha1" not in item or not item["sha1"]:
                sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
                item["sha1"] = sha1

    async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
        indexed_sha1_list = []

        stmt = select(self.orm).column(self.orm.sha1)
        if not with_temporary:
            stmt = stmt.where(self.orm.temp_doc_session_id == None)
        if not in_collection:
            stmt = stmt.where(self.orm.page_id == self.page_id)

        ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
        for row in ret:
            indexed_sha1_list.append(row.sha1)

        return indexed_sha1_list

    async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
        """Return the document chunks whose sha1 is not in the index yet."""
        indexed_sha1_list = await self.get_indexed_sha1(with_temporary)
        self.sha1_doc(doc)

        should_index = []
        for item in doc:
            if item["sha1"] not in indexed_sha1_list:
                should_index.append(item)

        return should_index

    async def remove_outdated_doc(self, doc: list):
        """Delete persisted chunks of this page that no longer appear in the document."""
        await self.clear_temp()

        indexed_sha1_list = await self.get_indexed_sha1(False)
        self.sha1_doc(doc)

        doc_sha1_list = [item["sha1"] for item in doc]

        should_remove = []
        for sha1 in indexed_sha1_list:
            if sha1 not in doc_sha1_list:
                should_remove.append(sha1)

        if len(should_remove) > 0:
            await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
                                   self.page_id, should_remove)

    async def index_doc(self, doc: list):
        """Persist new document chunks and promote matching temporary chunks."""
        need_create_index = False

        indexed_persist_sha1_list = []
        indexed_temp_sha1_list = []

        ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
                                   self.page_id)
        for row in ret:
            if row[1]:
                indexed_temp_sha1_list.append(row[0])
            else:
                indexed_persist_sha1_list.append(row[0])

        # Create the embedding index when the page has no persisted chunks yet.
        if len(indexed_persist_sha1_list) == 0:
            need_create_index = True

        self.sha1_doc(doc)

        doc_sha1_list = []
        should_index = []
        should_persist = []
        should_remove = []

        for item in doc:
            doc_sha1_list.append(item["sha1"])
            if item["sha1"] in indexed_temp_sha1_list:
                should_persist.append(item["sha1"])
            elif item["sha1"] not in indexed_persist_sha1_list:
                should_index.append(item)

        # Note: should_remove is collected here but not deleted in this method.
        for sha1 in indexed_persist_sha1_list:
            if sha1 not in doc_sha1_list:
                should_remove.append(sha1)

        if len(should_index) > 0:
            await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                                       VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
                                       [(item["sha1"], self.page_id, item["text"], len(item["text"]),
                                         item["markdown"], len(item["markdown"]), item["embedding"])
                                        for item in should_index])

        if len(should_persist) > 0:
            await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
                                       [(self.page_id, sha1) for sha1 in should_persist])

        if need_create_index:
            await self.create_embedding_index()

    async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
        """Add temporary document chunks to the index under the given session id."""
        indexed_sha1_list = []
        indexed_temp_sha1_list = []
        doc_sha1_list = []

        sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
            self.table_name)
        ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
        for row in ret:
            indexed_sha1_list.append(row[0])
            if row[1]:
                indexed_temp_sha1_list.append(row[0])

        self.sha1_doc(doc)

        should_index = []
        should_remove = []

        for item in doc:
            sha1 = item["sha1"]
            doc_sha1_list.append(sha1)
            if sha1 not in indexed_sha1_list:
                should_index.append(item)

        for sha1 in indexed_temp_sha1_list:
            if sha1 not in doc_sha1_list:
                should_remove.append(sha1)

        if len(should_index) > 0:
            await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
                                       [(item["sha1"], self.page_id, item["text"], len(item["text"]),
                                         item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id)
                                        for item in should_index])

        if len(should_remove) > 0:
            await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
                                   self.page_id, temp_doc_session_id, should_remove)

    async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
        """Search for text by cosine similarity."""
        # <=> is pgvector's cosine distance operator, matching the ivfflat
        # vector_cosine_ops index declared on the embedding column.
        if in_collection:
            return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <=> $1 AS distance
                                        FROM %s
                                        ORDER BY distance ASC
                                        LIMIT %d""" % (self.table_name, limit), embedding)
        else:
            return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <=> $1 AS distance
                                        FROM %s
                                        WHERE page_id = $2
                                        ORDER BY distance ASC
                                        LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)

    async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: Optional[int] = None):
        """Clear the temporary index (optionally restricted to one page and/or one session)."""
        sql = "DELETE FROM %s" % (self.table_name)

        where = []
        params = []

        if not in_collection:
            params.append(self.page_id)
            where.append("page_id = $%d" % len(params))

        if temp_doc_session_id:
            params.append(temp_doc_session_id)
            where.append("temp_doc_session_id = $%d" % len(params))
        else:
            where.append("temp_doc_session_id IS NOT NULL")

        if len(where) > 0:
            sql += " WHERE " + (" AND ".join(where))

        await self.dbi.execute(sql, *params)
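

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): one plausible
# way to drive PageIndexHelper when (re)indexing a page. The `dbs` service, the
# ids, and the shape of the `doc` items (dicts carrying "text", "markdown" and
# a precomputed "embedding" vector of size config.EMBEDDING_VECTOR_SIZE) are
# assumptions made for demonstration; adapt them to the real call site.
# ---------------------------------------------------------------------------
async def _example_reindex_page(dbs: DatabaseService, collection_id: int, page_id: int, doc: list):
    async with PageIndexHelper(dbs, collection_id, page_id) as helper:
        # Ensure the per-collection table exists before touching it.
        await helper.init_table()
        # Only chunks whose sha1 is not indexed yet need fresh embeddings.
        unindexed = await helper.get_unindexed_doc(doc)
        for item in unindexed:
            item["embedding"] = ...  # compute the embedding vector here (external step, assumed)
        # Persist new chunks and promote matching temporary ones.
        await helper.index_doc(doc)
        # Drop persisted chunks that no longer appear in the document.
        await helper.remove_outdated_doc(doc)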