import hashlib
from typing import Optional

import asyncpg
from api.model.base import BaseModel
import config
import numpy as np
import sqlalchemy
from sqlalchemy import select, update, delete
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession

from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector

from service.database import DatabaseService


class PageIndexModel(BaseModel):
    """Abstract ORM mapping with the same columns as the page index tables.

    The columns mirror the ``CREATE TABLE`` statement issued by
    ``PageIndexHelper.init_table``.
    """

    __abstract__ = True

    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # Hex SHA-1 digest (40 characters) of ``text``.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    markdown: Mapped[Optional[str]] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
    # NULL for persistent rows; set for rows that belong to a temporary session.
    temp_doc_session_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)


class PageIndexHelper:
    """Raw-SQL helper for the ``embedding_search_page_index_<collection_id>`` table.

    Use it as an async context manager: entering acquires a connection from
    ``dbs.pool`` and registers the pgvector codec on it, exiting releases the
    connection. All query methods require the helper to be entered first.
    """

    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id",
    ]

    def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
        """Bind the helper to one collection table and, optionally, one page.

        Args:
            dbs: Database service whose ``pool`` is used to acquire connections.
            collection_id: Selects the table ``embedding_search_page_index_<id>``.
            page_id: Page the per-page queries operate on; ``None`` is stored as ``-1``.
        """
        self.dbs = dbs
        self.collection_id = collection_id
        self.page_id = page_id if page_id is not None else -1
        self.table_name = "embedding_search_page_index_%s" % str(collection_id)
        self.initialized = False
        self.table_initialized = False

    async def __aenter__(self):
        """Acquire a pooled connection (once) and return the helper itself."""
        if not self.initialized:
            self.dbpool = self.dbs.pool.acquire()
            self.dbi = await self.dbpool.__aenter__()
            await register_vector(self.dbi)
            self.initialized = True
        # Always return self so ``async with helper as h`` never binds None.
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the pooled connection acquired by ``__aenter__``."""
        # Reset first so a later re-entry acquires a fresh connection instead of
        # reusing the released one.
        self.initialized = False
        await self.dbpool.__aexit__(exc_type, exc, tb)

    async def table_exists(self):
        """Return True if this helper's table exists in the ``public`` schema."""
        exists = await self.dbi.fetchval(
            """SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = 'public'
                AND table_name = $1
            );""",
            self.table_name,
            column=0,
        )
        return bool(exists)

    async def init_table(self):
        """Create the table and its btree indexes if the table does not exist yet.

        The existence check runs at most once per helper instance.
        """
        if self.table_initialized:
            return

        if not await self.table_exists():
            # "/*_*/" is a placeholder that is replaced with the table name.
            create_sql = ("""CREATE TABLE IF NOT EXISTS /*_*/ (
                id SERIAL PRIMARY KEY,
                page_id INTEGER NOT NULL,
                sha1 VARCHAR(40) NOT NULL,
                text TEXT NOT NULL,
                text_len INTEGER NOT NULL,
                embedding VECTOR(%d) NOT NULL,
                markdown TEXT NULL,
                markdown_len INTEGER NULL,
                temp_doc_session_id INTEGER NULL
            );
            CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
            CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
            CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
            """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name)
            await self.dbi.execute(create_sql)

        # Remember the check so later calls return immediately.
        self.table_initialized = True

    async def create_embedding_index(self):
        """Create the ivfflat (``vector_cosine_ops``) index on ``embedding`` if missing."""
        sql = (
            "CREATE INDEX IF NOT EXISTS /*_*/_embedding_idx ON /*_*/ "
            "USING ivfflat (embedding vector_cosine_ops);"
        ).replace("/*_*/", self.table_name)
        await self.dbi.execute(sql)

    def sha1_doc(self, doc: list):
        """Fill in ``item["sha1"]`` in place for every item lacking a non-empty one.

        The digest is the hex SHA-1 of the item's UTF-8 encoded ``text``.
        """
        for item in doc:
            if "sha1" not in item or not item["sha1"]:
                item["sha1"] = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()

    async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
        """Return the list of ``sha1`` values stored in the table.

        Args:
            with_temporary: When False, only rows with a NULL
                ``temp_doc_session_id`` are returned.
            in_collection: When False, only rows of ``self.page_id`` are
                returned; when True the page filter is dropped.
        """
        sql = "SELECT sha1 FROM %s" % self.table_name
        where = []
        params = []

        if not with_temporary:
            where.append("temp_doc_session_id IS NULL")
        if not in_collection:
            params.append(self.page_id)
            where.append("page_id = $%d" % len(params))
        if where:
            sql += " WHERE " + " AND ".join(where)

        rows = await self.dbi.fetch(sql, *params)
        return [row[0] for row in rows]

    async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
        """Return the items of ``doc`` whose sha1 is not yet stored for this page.

        ``doc`` items receive a ``sha1`` key in place if they lack one.
        """
        indexed = set(await self.get_indexed_sha1(with_temporary))
        self.sha1_doc(doc)
        return [item for item in doc if item["sha1"] not in indexed]

    async def remove_outdated_doc(self, doc: list):
        """Delete this page's rows that are no longer part of ``doc``.

        All temporary rows of the page are cleared first; then every persistent
        row whose sha1 does not occur in ``doc`` is deleted.
        """
        await self.clear_temp()

        indexed_sha1_list = await self.get_indexed_sha1(False)
        self.sha1_doc(doc)
        current = {item["sha1"] for item in doc}

        should_remove = [sha1 for sha1 in indexed_sha1_list if sha1 not in current]
        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % self.table_name,
                self.page_id,
                should_remove,
            )

    async def _insert_items(self, items: list, temp_doc_session_id: Optional[int]):
        """Insert ``items`` for ``self.page_id`` with the given session id (None = persistent)."""
        rows = []
        for item in items:
            # markdown is a nullable column, so tolerate a missing or None value.
            markdown = item.get("markdown")
            rows.append((
                item["sha1"],
                self.page_id,
                item["text"],
                len(item["text"]),
                markdown,
                len(markdown) if markdown is not None else None,
                item["embedding"],
                temp_doc_session_id,
            ))

        await self.dbi.executemany(
            """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % self.table_name,
            rows,
        )

    async def index_doc(self, doc: list):
        """Persist ``doc`` for this page.

        Items already stored as temporary rows are promoted to persistent rows
        (their ``temp_doc_session_id`` is set to NULL); items not stored at all
        are inserted. Rows that are absent from ``doc`` are not deleted here;
        ``remove_outdated_doc`` does that. If the page had no persistent rows
        beforehand, the embedding index is created afterwards.
        """
        rows = await self.dbi.fetch(
            "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % self.table_name,
            self.page_id,
        )
        persisted = set()
        temporary = set()
        for row in rows:
            if row[1]:
                temporary.add(row[0])
            else:
                persisted.add(row[0])

        # Create the embedding index when the page had no persistent rows yet.
        need_create_index = not persisted

        self.sha1_doc(doc)

        should_index = []
        should_persist = []
        for item in doc:
            sha1 = item["sha1"]
            if sha1 in temporary:
                should_persist.append(sha1)
            elif sha1 not in persisted:
                should_index.append(item)

        if should_index:
            await self._insert_items(should_index, None)

        if should_persist:
            await self.dbi.executemany(
                "UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % self.table_name,
                [(self.page_id, sha1) for sha1 in should_persist],
            )

        if need_create_index:
            await self.create_embedding_index()

    async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
        """Add a temporary document to the index.

        Items of ``doc`` not already stored for this page (as persistent rows or
        as rows of this session) are inserted with ``temp_doc_session_id``. Rows
        of this session whose sha1 no longer occurs in ``doc`` are deleted.
        """
        sql = (
            "SELECT sha1, temp_doc_session_id FROM %s "
            "WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)"
            % self.table_name
        )
        rows = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)

        indexed = set()
        indexed_temp = []
        for row in rows:
            indexed.add(row[0])
            if row[1]:
                indexed_temp.append(row[0])

        self.sha1_doc(doc)
        current = {item["sha1"] for item in doc}

        should_index = [item for item in doc if item["sha1"] not in indexed]
        should_remove = [sha1 for sha1 in indexed_temp if sha1 not in current]

        if should_index:
            await self._insert_items(should_index, temp_doc_session_id)

        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)"
                % self.table_name,
                self.page_id,
                temp_doc_session_id,
                should_remove,
            )

    async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
        """Return the rows nearest to ``embedding``, closest first.

        Args:
            embedding: Query vector.
            in_collection: When True search the whole table, otherwise only the
                rows of ``self.page_id``.
            limit: Maximum number of rows returned.

        Returns:
            Records with ``id, sha1, text, text_len, markdown, markdown_len`` and
            ``distance``.
        """
        # NOTE(review): "<->" is pgvector's L2 distance, while the index created
        # by create_embedding_index uses vector_cosine_ops (served by "<=>").
        # The operator is kept because changing it would alter the returned
        # distance values - confirm the intended metric.
        params = [embedding]
        where = ""
        if not in_collection:
            params.append(self.page_id)
            where = " WHERE page_id = $%d" % len(params)

        sql = (
            "SELECT id, sha1, text, text_len, markdown, markdown_len, "
            "embedding <-> $1 AS distance FROM %s%s ORDER BY distance ASC LIMIT %d"
            % (self.table_name, where, limit)
        )
        return await self.dbi.fetch(sql, *params)

    async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: Optional[int] = None):
        """Delete temporary rows from the index.

        Args:
            in_collection: When False only rows of ``self.page_id`` are deleted;
                when True the page filter is dropped.
            temp_doc_session_id: When truthy only rows of that session are
                deleted; otherwise every row with a non-NULL session id.
        """
        where = []
        params = []

        if not in_collection:
            params.append(self.page_id)
            where.append("page_id = $%d" % len(params))

        if temp_doc_session_id:
            params.append(temp_doc_session_id)
            where.append("temp_doc_session_id = $%d" % len(params))
        else:
            where.append("temp_doc_session_id IS NOT NULL")

        # ``where`` always holds at least the session condition at this point.
        sql = "DELETE FROM %s WHERE %s" % (self.table_name, " AND ".join(where))
        await self.dbi.execute(sql, *params)