|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
|
|
from typing import Optional, Type
|
|
|
|
|
|
|
|
import asyncpg
|
|
|
|
from api.model.base import BaseModel
|
|
|
|
import config
|
|
|
|
import numpy as np
|
|
|
|
import sqlalchemy
|
|
|
|
from sqlalchemy import Index, select, update, delete, Select
|
|
|
|
from sqlalchemy.orm import mapped_column, Mapped
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from pgvector.asyncpg import register_vector
|
|
|
|
from pgvector.sqlalchemy import Vector
|
|
|
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
|
|
|
# Cache of dynamically generated per-collection ORM model classes,
# keyed by collection id (see create_page_index_model below).
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractPageIndexModel(BaseModel):
    """Abstract ORM base for the per-collection embedding-search index tables.

    Concrete subclasses are produced by ``create_page_index_model`` with a
    ``__tablename__`` of ``embedding_search_page_index_<collection_id>``.
    Each row stores one chunk of a page's text together with its embedding.
    """

    __abstract__ = True

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Page the chunk belongs to.
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # SHA-1 hex digest of `text`; used to detect already-indexed chunks.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    # Embedding vector; dimension comes from config.EMBEDDING_VECTOR_SIZE.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    # Plain-text chunk content and its cached length.
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    # Optional markdown rendering of the chunk and its cached length.
    markdown: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
    # Non-NULL while the chunk only belongs to a temporary document session.
    temp_doc_session_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
|
|
|
|
|
|
|
|
|
|
|
|
def create_page_index_model(collection_id: int):
    """Return the ORM model class for one collection's page-index table.

    Model classes are created lazily and cached in ``page_index_model_list``
    so each collection id maps to exactly one class (SQLAlchemy forbids
    mapping the same table twice).
    """
    cached = page_index_model_list.get(collection_id)
    if cached is not None:
        return cached

    table_name = "embedding_search_page_index_%s" % str(collection_id)

    class PageIndexModel(AbstractPageIndexModel):
        __tablename__ = table_name

        # Approximate-nearest-neighbour index over the embedding column,
        # using ivfflat with cosine-distance ordering.
        embedding_index = sqlalchemy.Index(
            table_name + "_embedding_idx",
            AbstractPageIndexModel.embedding,
            postgresql_using='ivfflat',
            postgresql_ops={'embedding': 'vector_cosine_ops'},
        )

    page_index_model_list[collection_id] = PageIndexModel
    return PageIndexModel
|
|
|
|
|
|
|
|
class PageIndexHelper:
    """Manages one collection's embedding-search index table.

    Combines a raw asyncpg connection (for vector queries and bulk SQL)
    with a SQLAlchemy session (for ORM queries); use as an async context
    manager.
    """

    # Column names of the per-collection index table, used when building
    # raw SQL statements.
    # BUGFIX: "page_id" previously had no trailing comma, so Python's
    # implicit string concatenation fused it with "sha1" into the single
    # bogus entry "page_idsha1".
    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id"
    ]
|
|
|
|
|
|
|
|
def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
|
|
|
|
self.dbs = dbs
|
|
|
|
self.collection_id = collection_id
|
|
|
|
self.page_id = page_id if page_id is not None else -1
|
|
|
|
self.table_name = "embedding_search_page_index_%s" % str(collection_id)
|
|
|
|
self.initialized = False
|
|
|
|
self.table_initialized = False
|
|
|
|
|
|
|
|
"""
|
|
|
|
Initialize table
|
|
|
|
"""
|
|
|
|
async def __aenter__(self):
|
|
|
|
if self.initialized:
|
|
|
|
return
|
|
|
|
|
|
|
|
self.dbpool = self.dbs.pool.acquire()
|
|
|
|
self.dbi = await self.dbpool.__aenter__()
|
|
|
|
|
|
|
|
await register_vector(self.dbi)
|
|
|
|
|
|
|
|
self.create_session = self.dbs.create_session
|
|
|
|
self.session = self.dbs.create_session()
|
|
|
|
await self.session.__aenter__()
|
|
|
|
|
|
|
|
self.orm = create_page_index_model(self.collection_id)
|
|
|
|
|
|
|
|
self.initialized = True
|
|
|
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
|
|
await self.dbpool.__aexit__(exc_type, exc, tb)
|
|
|
|
await self.session.__aexit__(exc_type, exc, tb)
|
|
|
|
|
|
|
|
    async def table_exists(self):
        """Return True when this collection's index table exists in the ``public`` schema."""
        exists = await self.dbi.fetchval("""SELECT EXISTS (
            SELECT 1
            FROM information_schema.tables
            WHERE table_schema = 'public'
            AND table_name = $1
        );""", self.table_name, column=0)

        return bool(exists)
|
|
|
|
|
|
|
|
    async def init_table(self):
        """Create the per-collection index table if it does not exist yet.

        Idempotent per helper instance.  DDL is emitted through the
        SQLAlchemy engine from the dynamically generated ORM model's
        table definition, which also carries the ivfflat embedding index.
        """
        if self.table_initialized:
            return

        # create table if not exists
        if not await self.table_exists():
            async with self.dbs.engine.begin() as conn:
                await conn.run_sync(self.orm.__table__.create)

        self.table_initialized = True
|
|
|
|
|
|
|
|
    async def create_embedding_index(self):
        """No-op placeholder.

        The ivfflat embedding index is declared on the generated ORM model
        (see ``create_page_index_model``) and created together with the
        table, so there is nothing left to do here; callers may still
        invoke this hook after the first persistent insert.
        """
        pass
|
|
|
|
|
|
|
|
def sha1_doc(self, doc: list):
|
|
|
|
for item in doc:
|
|
|
|
if "sha1" not in item or not item["sha1"]:
|
|
|
|
sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
|
|
|
|
item["sha1"] = sha1
|
|
|
|
|
|
|
|
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
|
|
|
|
indexed_sha1_list = []
|
|
|
|
|
|
|
|
stmt = select(self.orm).column(self.orm.sha1)
|
|
|
|
|
|
|
|
if not with_temporary:
|
|
|
|
stmt = stmt.where(self.orm.temp_doc_session_id == None)
|
|
|
|
|
|
|
|
if not in_collection:
|
|
|
|
stmt = stmt.where(self.orm.page_id == self.page_id)
|
|
|
|
|
|
|
|
ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
|
|
|
|
|
|
|
|
for row in ret:
|
|
|
|
indexed_sha1_list.append(row.sha1)
|
|
|
|
|
|
|
|
return indexed_sha1_list
|
|
|
|
|
|
|
|
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(with_temporary)
|
|
|
|
self.sha1_doc(doc)
|
|
|
|
|
|
|
|
should_index = []
|
|
|
|
for item in doc:
|
|
|
|
if item["sha1"] not in indexed_sha1_list:
|
|
|
|
should_index.append(item)
|
|
|
|
|
|
|
|
return should_index
|
|
|
|
|
|
|
|
async def remove_outdated_doc(self, doc: list):
|
|
|
|
await self.clear_temp()
|
|
|
|
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(False)
|
|
|
|
self.sha1_doc(doc)
|
|
|
|
|
|
|
|
doc_sha1_list = [item["sha1"] for item in doc]
|
|
|
|
|
|
|
|
should_remove = []
|
|
|
|
for sha1 in indexed_sha1_list:
|
|
|
|
if sha1 not in doc_sha1_list:
|
|
|
|
should_remove.append(sha1)
|
|
|
|
|
|
|
|
if len(should_remove) > 0:
|
|
|
|
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
|
|
|
|
self.page_id, should_remove)
|
|
|
|
|
|
|
|
    async def index_doc(self, doc: list):
        """Persistently index the given document chunks for this page.

        Chunks never seen before are inserted with a NULL session id;
        chunks previously indexed under a temporary session are promoted
        to persistent by clearing their session id.  When the page had no
        persistent rows at all, the embedding-index hook is invoked after
        the inserts.
        """
        need_create_index = False

        indexed_persist_sha1_list = []
        indexed_temp_sha1_list = []

        # Partition the hashes already stored for this page by whether they
        # belong to a temporary session (truthy temp_doc_session_id).
        ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
                                   self.page_id)
        for row in ret:
            if row[1]:
                indexed_temp_sha1_list.append(row[0])
            else:
                indexed_persist_sha1_list.append(row[0])

        # Create index when no indexed document
        if len(indexed_persist_sha1_list) == 0:
            need_create_index = True

        # Ensure every chunk carries its content hash.
        self.sha1_doc(doc)

        doc_sha1_list = []

        should_index = []
        should_persist = []
        should_remove = []
        for item in doc:
            doc_sha1_list.append(item["sha1"])

            if item["sha1"] in indexed_temp_sha1_list:
                # Already stored under a temp session: just clear the session id.
                should_persist.append(item["sha1"])
            elif item["sha1"] not in indexed_persist_sha1_list:
                should_index.append(item)

        # Persistent rows whose content no longer appears in the document.
        # NOTE(review): should_remove is collected but never acted on here —
        # deletion happens in remove_outdated_doc; confirm this is intended.
        for sha1 in indexed_persist_sha1_list:
            if sha1 not in doc_sha1_list:
                should_remove.append(sha1)

        if len(should_index) > 0:
            # Insert new chunks as persistent rows (NULL session id).
            await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
            VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
                                       [(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index])

        if len(should_persist) > 0:
            # Promote matching temp rows to persistent.
            await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
                                       [(self.page_id, sha1) for sha1 in should_persist])

        if need_create_index:
            await self.create_embedding_index()
|
|
|
|
|
|
|
|
"""
|
|
|
|
Add temporary document to the index
|
|
|
|
"""
|
|
|
|
async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
|
|
|
|
indexed_sha1_list = []
|
|
|
|
indexed_temp_sha1_list = []
|
|
|
|
doc_sha1_list = []
|
|
|
|
|
|
|
|
sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
|
|
|
|
self.table_name)
|
|
|
|
ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
|
|
|
|
for row in ret:
|
|
|
|
indexed_sha1_list.append(row[0])
|
|
|
|
if row[1]:
|
|
|
|
indexed_temp_sha1_list.append(row[0])
|
|
|
|
|
|
|
|
self.sha1_doc(doc)
|
|
|
|
|
|
|
|
should_index = []
|
|
|
|
should_remove = []
|
|
|
|
|
|
|
|
for item in doc:
|
|
|
|
sha1 = item["sha1"]
|
|
|
|
doc_sha1_list.append(sha1)
|
|
|
|
if sha1 not in indexed_sha1_list:
|
|
|
|
should_index.append(item)
|
|
|
|
|
|
|
|
for sha1 in indexed_temp_sha1_list:
|
|
|
|
if sha1 not in doc_sha1_list:
|
|
|
|
should_remove.append(sha1)
|
|
|
|
|
|
|
|
if len(should_index) > 0:
|
|
|
|
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
|
|
|
|
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index])
|
|
|
|
|
|
|
|
if len(should_remove) > 0:
|
|
|
|
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
|
|
|
|
self.page_id, temp_doc_session_id, should_remove)
|
|
|
|
|
|
|
|
"""
|
|
|
|
Search for text by consine similary
|
|
|
|
"""
|
|
|
|
async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
|
|
|
|
if in_collection:
|
|
|
|
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
FROM %s
|
|
|
|
ORDER BY distance ASC
|
|
|
|
LIMIT %d""" % (self.table_name, limit), embedding)
|
|
|
|
else:
|
|
|
|
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
FROM %s
|
|
|
|
WHERE page_id = $2
|
|
|
|
ORDER BY distance ASC
|
|
|
|
LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)
|
|
|
|
|
|
|
|
"""
|
|
|
|
Clear temporary index
|
|
|
|
"""
|
|
|
|
async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None):
|
|
|
|
sql = "DELETE FROM %s" % (self.table_name)
|
|
|
|
|
|
|
|
where = []
|
|
|
|
params = []
|
|
|
|
|
|
|
|
if not in_collection:
|
|
|
|
params.append(self.page_id)
|
|
|
|
where.append("page_id = $%d" % len(params))
|
|
|
|
|
|
|
|
if temp_doc_session_id:
|
|
|
|
params.append(temp_doc_session_id)
|
|
|
|
where.append("temp_doc_session_id = $%d" % len(params))
|
|
|
|
else:
|
|
|
|
where.append("temp_doc_session_id IS NOT NULL")
|
|
|
|
|
|
|
|
if len(where) > 0:
|
|
|
|
sql += " WHERE " + (" AND ".join(where))
|
|
|
|
|
|
|
|
await self.dbi.execute(sql, *params)
|