# Per-collection page embedding index.
#
# Defines the abstract ORM model shared by every collection's page index
# table, a factory that builds (and caches) the concrete model for one
# collection, and PageIndexHelper, which reads and writes that table through
# both a raw pooled connection and an ORM session.
from __future__ import annotations

import hashlib
from typing import Optional, Type

import asyncpg
from api.model.base import BaseModel
import numpy as np
import sqlalchemy
from config import Config
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector
from service.database import DatabaseService

# Cache of concrete models, keyed by collection id, so that each table name is
# only declared once per process.
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}

# Dimension of the stored embedding vectors (config key
# "chatcomplete.embedding_vector_size", default 512).
embedding_vector_size = Config.get("chatcomplete.embedding_vector_size", 512, int)


class AbstractPageIndexModel(BaseModel):
    """Column layout shared by every per-collection page index table."""

    __abstract__ = True

    id: Mapped[int] = mapped_column(
        sqlalchemy.Integer, primary_key=True, autoincrement=True
    )
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # 40 characters: hex SHA-1 digest (see PageIndexHelper.sha1_doc).
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    embedding: Mapped[np.ndarray] = mapped_column(Vector(embedding_vector_size))
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
    # NULL for persisted rows; set for rows that belong to a temporary
    # document session.
    temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)


def create_page_index_model(collection_id: int) -> Type[AbstractPageIndexModel]:
    """Return the ORM model for one collection's page index table.

    The model is created on first use and cached in
    ``page_index_model_list``; later calls with the same ``collection_id``
    return the cached class.
    """
    cached = page_index_model_list.get(collection_id)
    if cached is not None:
        return cached

    class _PageIndexModel(AbstractPageIndexModel):
        __tablename__ = "embedding_search_page_index_%s" % str(collection_id)

        embedding_index = sqlalchemy.Index(
            __tablename__ + "_embedding_idx",
            AbstractPageIndexModel.embedding,
            postgresql_using="ivfflat",
            postgresql_ops={"embedding": "vector_cosine_ops"},
        )

    page_index_model_list[collection_id] = _PageIndexModel
    return _PageIndexModel


class PageIndexHelper:
    """Read/write access to the page index table of a single collection.

    Use as an async context manager: entering acquires a pooled connection
    (``self.dbi``) and opens an ORM session (``self.session``); leaving
    releases both.
    """

    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id",
    ]

    def __init__(self, dbs: DatabaseService, collection_id: int):
        self.dbs = dbs
        self.collection_id = collection_id
        self.table_name = "embedding_search_page_index_%s" % str(collection_id)
        self.initialized = False
        self.table_initialized = False

    async def __aenter__(self):
        """Acquire the connection and session, then return ``self``.

        Re-entering an already initialized helper is a no-op that still
        returns ``self`` so ``async with helper as h`` always binds the helper.
        """
        if self.initialized:
            return self

        self.dbpool = self.dbs.pool.acquire()
        self.dbi = await self.dbpool.__aenter__()
        try:
            await register_vector(self.dbi)

            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
        except BaseException as err:
            # Do not leak the pooled connection when setup fails half-way.
            await self.dbpool.__aexit__(type(err), err, err.__traceback__)
            raise

        self.orm = create_page_index_model(self.collection_id)
        self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the session and release the connection.

        Resources are released in reverse order of acquisition, and
        ``initialized`` is reset so the helper can be entered again.
        """
        self.initialized = False
        try:
            await self.session.__aexit__(exc_type, exc, tb)
        finally:
            await self.dbpool.__aexit__(exc_type, exc, tb)

    async def table_exists(self) -> bool:
        """Return whether this collection's table exists in schema ``public``."""
        exists = await self.dbi.fetchval(
            """SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = 'public'
                    AND table_name = $1
            );""",
            self.table_name,
            column=0,
        )
        return bool(exists)

    async def init_table(self):
        """Create the table if it does not exist yet (checked once per helper)."""
        if self.table_initialized:
            return

        if not await self.table_exists():
            async with self.dbs.engine.begin() as conn:
                await conn.run_sync(self.orm.__table__.create)

        self.table_initialized = True

    async def create_embedding_index(self):
        """Placeholder: currently does nothing.

        The statement that used to live here is kept for reference:
        CREATE INDEX IF NOT EXISTS <table>_embedding_idx ON <table>
        USING ivfflat (embedding vector_cosine_ops);
        """
        pass

    def sha1_doc(self, doc: list):
        """Fill in ``item["sha1"]`` (hex SHA-1 of ``item["text"]``) in place.

        Items that already carry a non-empty ``sha1`` are left untouched.
        """
        for item in doc:
            if not item.get("sha1"):
                item["sha1"] = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()

    async def get_indexed_sha1(
        self, page_id: int, with_temporary: bool = True, in_collection: bool = False
    ) -> list:
        """Return the ``sha1`` values of rows that are already indexed.

        Args:
            page_id: Page to look at; ignored when ``in_collection`` is true.
            with_temporary: When false, only rows whose
                ``temp_doc_session_id`` is NULL are considered.
            in_collection: When true, do not filter by ``page_id``.
        """
        # Only the sha1 column is needed, so avoid loading full rows
        # (including their embedding vectors).
        stmt = select(self.orm.sha1)
        if not with_temporary:
            stmt = stmt.where(self.orm.temp_doc_session_id.is_(None))
        if not in_collection:
            stmt = stmt.where(self.orm.page_id == page_id)

        result = await self.session.scalars(stmt)
        return list(result)

    async def get_unindexed_doc(
        self, doc: list, page_id: int, with_temporary: bool = True
    ) -> list:
        """Return the items of ``doc`` whose sha1 is not indexed for ``page_id``.

        ``doc`` items get their ``sha1`` filled in as a side effect.
        """
        indexed = set(await self.get_indexed_sha1(page_id, with_temporary))
        self.sha1_doc(doc)
        return [item for item in doc if item["sha1"] not in indexed]

    async def remove_outdated_doc(self, doc: list, page_id: int):
        """Delete rows of ``page_id`` that are no longer part of ``doc``.

        Temporary rows of the page are cleared first; then every persisted
        row whose sha1 does not occur in ``doc`` is deleted.
        """
        await self.clear_temp(page_id=page_id)

        indexed_sha1_list = await self.get_indexed_sha1(page_id, False)
        self.sha1_doc(doc)

        current = {item["sha1"] for item in doc}
        should_remove = [sha1 for sha1 in indexed_sha1_list if sha1 not in current]

        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)"
                % (self.table_name),
                page_id,
                should_remove,
            )

    @staticmethod
    def _insert_values(
        item: dict, page_id: int, temp_doc_session_id: Optional[int]
    ) -> tuple:
        """Build the parameter tuple for one INSERT from a document item.

        A missing or ``None`` markdown is stored as NULL with a NULL length
        instead of raising, matching the nullable columns.
        """
        markdown = item.get("markdown")
        return (
            item["sha1"],
            page_id,
            item["text"],
            len(item["text"]),
            markdown,
            None if markdown is None else len(markdown),
            item["embedding"],
            temp_doc_session_id,
        )

    async def _insert_rows(self, rows: list):
        """Insert parameter tuples produced by ``_insert_values``."""
        await self.dbi.executemany(
            """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8);"""
            % (self.table_name),
            rows,
        )

    async def index_doc(self, doc: list, page_id: int):
        """Persist ``doc`` for ``page_id``.

        Items that are not indexed yet are inserted as persisted rows; items
        that only exist as temporary rows are promoted by clearing their
        ``temp_doc_session_id``. Rows that are no longer in ``doc`` are not
        deleted here (``remove_outdated_doc`` does that).
        """
        persisted = set()
        temporary = set()

        rows = await self.dbi.fetch(
            "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1"
            % (self.table_name),
            page_id,
        )
        for row in rows:
            if row[1]:
                temporary.add(row[0])
            else:
                persisted.add(row[0])

        # Create index when no indexed document
        need_create_index = not persisted

        self.sha1_doc(doc)

        should_index = []
        should_persist = []
        for item in doc:
            sha1 = item["sha1"]
            if sha1 in temporary:
                should_persist.append(sha1)
            elif sha1 not in persisted:
                should_index.append(item)

        if should_index:
            await self._insert_rows(
                [self._insert_values(item, page_id, None) for item in should_index]
            )

        if should_persist:
            await self.dbi.executemany(
                "UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2"
                % (self.table_name),
                [(page_id, sha1) for sha1 in should_persist],
            )

        if need_create_index:
            await self.create_embedding_index()

    async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
        """Add a temporary document to the index.

        Items whose sha1 is neither persisted nor already stored for this
        session are inserted with ``temp_doc_session_id`` set; each item
        supplies its own ``page_id``. Rows of this session whose sha1 is not
        in ``doc`` any more are deleted.
        """
        sql = (
            "SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)"
            % (self.table_name)
        )
        rows = await self.dbi.fetch(sql, temp_doc_session_id)

        indexed = set()
        indexed_temp = []
        for row in rows:
            indexed.add(row[0])
            if row[1]:
                indexed_temp.append(row[0])

        self.sha1_doc(doc)

        current = {item["sha1"] for item in doc}
        should_index = [item for item in doc if item["sha1"] not in indexed]
        should_remove = [sha1 for sha1 in indexed_temp if sha1 not in current]

        if should_index:
            await self._insert_rows(
                [
                    self._insert_values(item, item["page_id"], temp_doc_session_id)
                    for item in should_index
                ]
            )

        if should_remove:
            await self.dbi.execute(
                "DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)"
                % (self.table_name),
                temp_doc_session_id,
                should_remove,
            )

    async def search_text_embedding(
        self,
        embedding: np.ndarray,
        in_collection: bool = False,
        limit: int = 10,
        page_id: Optional[int] = None,
    ):
        """Return the ``limit`` rows nearest to ``embedding``.

        Searches the whole collection when ``in_collection`` is true,
        otherwise only rows of ``page_id``. Rows are ordered by ascending
        ``distance``.

        NOTE(review): the query uses pgvector's ``<->`` operator (Euclidean
        distance), while the index declared on the model uses
        ``vector_cosine_ops`` and this method was originally described as a
        cosine-similarity search; ``<=>`` would be the cosine operator.
        Left unchanged because callers may depend on the returned distance
        values - confirm the intent before switching.
        """
        if in_collection:
            return await self.dbi.fetch(
                """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
                    FROM %s
                    ORDER BY distance ASC
                    LIMIT %d"""
                % (self.table_name, limit),
                embedding,
            )

        return await self.dbi.fetch(
            """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
                FROM %s
                WHERE page_id = $2
                ORDER BY distance ASC
                LIMIT %d"""
            % (self.table_name, limit),
            embedding,
            page_id,
        )

    async def clear_temp(
        self,
        in_collection: bool = False,
        temp_doc_session_id: Optional[int] = None,
        page_id: Optional[int] = None,
    ):
        """Delete temporary rows.

        With ``temp_doc_session_id`` only rows of that session are deleted,
        otherwise every row that has any session id. The deletion is limited
        to ``page_id`` when one is given and ``in_collection`` is false.
        """
        where = []
        params = []

        if not in_collection and page_id:
            params.append(page_id)
            where.append("page_id = $%d" % len(params))

        if temp_doc_session_id:
            params.append(temp_doc_session_id)
            where.append("temp_doc_session_id = $%d" % len(params))
        else:
            where.append("temp_doc_session_id IS NOT NULL")

        # A temp_doc_session_id condition is always present, so the
        # statement can never degrade into an unconditional DELETE.
        sql = "DELETE FROM %s WHERE %s" % (self.table_name, " AND ".join(where))

        await self.dbi.execute(sql, *params)

    async def remove_by_page_id(self, page_id: int):
        """Delete every row of ``page_id`` through the ORM session and commit."""
        stmt = delete(self.orm).where(self.orm.page_id == page_id)
        await self.session.execute(stmt)
        await self.session.commit()