|
|
|
@ -21,7 +21,9 @@ page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
|
|
|
|
|
class AbstractPageIndexModel(BaseModel):
|
|
|
|
|
__abstract__ = True
|
|
|
|
|
|
|
|
|
|
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
|
|
|
|
id: Mapped[int] = mapped_column(
|
|
|
|
|
sqlalchemy.Integer, primary_key=True, autoincrement=True
|
|
|
|
|
)
|
|
|
|
|
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
|
|
|
|
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
|
|
|
|
|
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
|
|
|
|
@ -36,34 +38,37 @@ def create_page_index_model(collection_id: int):
|
|
|
|
|
if collection_id in page_index_model_list:
|
|
|
|
|
return page_index_model_list[collection_id]
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
class _PageIndexModel(AbstractPageIndexModel):
|
|
|
|
|
__tablename__ = "embedding_search_page_index_%s" % str(collection_id)
|
|
|
|
|
|
|
|
|
|
embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx", AbstractPageIndexModel.embedding,
|
|
|
|
|
postgresql_using='ivfflat',
|
|
|
|
|
postgresql_ops={'embedding': 'vector_cosine_ops'})
|
|
|
|
|
|
|
|
|
|
embedding_index = sqlalchemy.Index(
|
|
|
|
|
__tablename__ + "_embedding_idx",
|
|
|
|
|
AbstractPageIndexModel.embedding,
|
|
|
|
|
postgresql_using="ivfflat",
|
|
|
|
|
postgresql_ops={"embedding": "vector_cosine_ops"},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
page_index_model_list[collection_id] = _PageIndexModel
|
|
|
|
|
|
|
|
|
|
return _PageIndexModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PageIndexHelper:
|
|
|
|
|
columns = [
|
|
|
|
|
"id",
|
|
|
|
|
"page_id"
|
|
|
|
|
"sha1",
|
|
|
|
|
"page_id" "sha1",
|
|
|
|
|
"text",
|
|
|
|
|
"text_len",
|
|
|
|
|
"markdown",
|
|
|
|
|
"markdown_len",
|
|
|
|
|
"embedding",
|
|
|
|
|
"temp_doc_session_id"
|
|
|
|
|
"temp_doc_session_id",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
|
|
|
|
|
def __init__(self, dbs: DatabaseService, collection_id: int):
|
|
|
|
|
self.dbs = dbs
|
|
|
|
|
self.collection_id = collection_id
|
|
|
|
|
self.page_id = page_id if page_id is not None else -1
|
|
|
|
|
self.table_name = "embedding_search_page_index_%s" % str(collection_id)
|
|
|
|
|
self.initialized = False
|
|
|
|
|
self.table_initialized = False
|
|
|
|
@ -71,10 +76,11 @@ class PageIndexHelper:
|
|
|
|
|
"""
|
|
|
|
|
Initialize table
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
async def __aenter__(self):
|
|
|
|
|
if self.initialized:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.dbpool = self.dbs.pool.acquire()
|
|
|
|
|
self.dbi = await self.dbpool.__aenter__()
|
|
|
|
|
|
|
|
|
@ -95,12 +101,16 @@ class PageIndexHelper:
|
|
|
|
|
await self.session.__aexit__(exc_type, exc, tb)
|
|
|
|
|
|
|
|
|
|
async def table_exists(self):
|
|
|
|
|
exists = await self.dbi.fetchval("""SELECT EXISTS (
|
|
|
|
|
exists = await self.dbi.fetchval(
|
|
|
|
|
"""SELECT EXISTS (
|
|
|
|
|
SELECT 1
|
|
|
|
|
FROM information_schema.tables
|
|
|
|
|
WHERE table_schema = 'public'
|
|
|
|
|
AND table_name = $1
|
|
|
|
|
);""", self.table_name, column=0)
|
|
|
|
|
);""",
|
|
|
|
|
self.table_name,
|
|
|
|
|
column=0,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return bool(exists)
|
|
|
|
|
|
|
|
|
@ -113,22 +123,6 @@ class PageIndexHelper:
|
|
|
|
|
async with self.dbs.engine.begin() as conn:
|
|
|
|
|
await conn.run_sync(self.orm.__table__.create)
|
|
|
|
|
|
|
|
|
|
# await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
|
|
|
|
|
# id SERIAL PRIMARY KEY,
|
|
|
|
|
# page_id INTEGER NOT NULL,
|
|
|
|
|
# sha1 VARCHAR(40) NOT NULL,
|
|
|
|
|
# text TEXT NOT NULL,
|
|
|
|
|
# text_len INTEGER NOT NULL,
|
|
|
|
|
# embedding VECTOR(%d) NOT NULL,
|
|
|
|
|
# markdown TEXT NULL,
|
|
|
|
|
# markdown_len INTEGER NULL,
|
|
|
|
|
# temp_doc_session_id INTEGER NULL
|
|
|
|
|
# );
|
|
|
|
|
# CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
|
|
|
|
|
# CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
|
|
|
|
|
# CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
|
|
|
|
|
# """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
|
|
|
|
|
|
|
|
|
|
self.table_initialized = True
|
|
|
|
|
|
|
|
|
|
async def create_embedding_index(self):
|
|
|
|
@ -142,7 +136,9 @@ class PageIndexHelper:
|
|
|
|
|
sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
|
|
|
|
|
item["sha1"] = sha1
|
|
|
|
|
|
|
|
|
|
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
|
|
|
|
|
async def get_indexed_sha1(
|
|
|
|
|
self, page_id: int, with_temporary: bool = True, in_collection: bool = False
|
|
|
|
|
):
|
|
|
|
|
indexed_sha1_list = []
|
|
|
|
|
|
|
|
|
|
stmt = select(self.orm).column(self.orm.sha1)
|
|
|
|
@ -151,7 +147,7 @@ class PageIndexHelper:
|
|
|
|
|
stmt = stmt.where(self.orm.temp_doc_session_id == None)
|
|
|
|
|
|
|
|
|
|
if not in_collection:
|
|
|
|
|
stmt = stmt.where(self.orm.page_id == self.page_id)
|
|
|
|
|
stmt = stmt.where(self.orm.page_id == page_id)
|
|
|
|
|
|
|
|
|
|
ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
|
|
|
|
|
|
|
|
|
@ -160,8 +156,10 @@ class PageIndexHelper:
|
|
|
|
|
|
|
|
|
|
return indexed_sha1_list
|
|
|
|
|
|
|
|
|
|
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
|
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(with_temporary)
|
|
|
|
|
async def get_unindexed_doc(
|
|
|
|
|
self, doc: list, page_id: int, with_temporary: bool = True
|
|
|
|
|
):
|
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(page_id, with_temporary)
|
|
|
|
|
self.sha1_doc(doc)
|
|
|
|
|
|
|
|
|
|
should_index = []
|
|
|
|
@ -171,10 +169,10 @@ class PageIndexHelper:
|
|
|
|
|
|
|
|
|
|
return should_index
|
|
|
|
|
|
|
|
|
|
async def remove_outdated_doc(self, doc: list):
|
|
|
|
|
await self.clear_temp()
|
|
|
|
|
async def remove_outdated_doc(self, doc: list, page_id: int):
|
|
|
|
|
await self.clear_temp(page_id=page_id)
|
|
|
|
|
|
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(False)
|
|
|
|
|
indexed_sha1_list = await self.get_indexed_sha1(page_id, False)
|
|
|
|
|
self.sha1_doc(doc)
|
|
|
|
|
|
|
|
|
|
doc_sha1_list = [item["sha1"] for item in doc]
|
|
|
|
@ -185,17 +183,24 @@ class PageIndexHelper:
|
|
|
|
|
should_remove.append(sha1)
|
|
|
|
|
|
|
|
|
|
if len(should_remove) > 0:
|
|
|
|
|
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
|
|
|
|
|
self.page_id, should_remove)
|
|
|
|
|
|
|
|
|
|
async def index_doc(self, doc: list):
|
|
|
|
|
await self.dbi.execute(
|
|
|
|
|
"DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)"
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
page_id,
|
|
|
|
|
should_remove,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def index_doc(self, doc: list, page_id: int):
|
|
|
|
|
need_create_index = False
|
|
|
|
|
|
|
|
|
|
indexed_persist_sha1_list = []
|
|
|
|
|
indexed_temp_sha1_list = []
|
|
|
|
|
|
|
|
|
|
ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
|
|
|
|
|
self.page_id)
|
|
|
|
|
ret = await self.dbi.fetch(
|
|
|
|
|
"SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1"
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
page_id,
|
|
|
|
|
)
|
|
|
|
|
for row in ret:
|
|
|
|
|
if row[1]:
|
|
|
|
|
indexed_temp_sha1_list.append(row[0])
|
|
|
|
@ -226,28 +231,48 @@ class PageIndexHelper:
|
|
|
|
|
should_remove.append(sha1)
|
|
|
|
|
|
|
|
|
|
if len(should_index) > 0:
|
|
|
|
|
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
|
|
|
|
|
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index])
|
|
|
|
|
await self.dbi.executemany(
|
|
|
|
|
"""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);"""
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
[
|
|
|
|
|
(
|
|
|
|
|
item["sha1"],
|
|
|
|
|
page_id,
|
|
|
|
|
item["text"],
|
|
|
|
|
len(item["text"]),
|
|
|
|
|
item["markdown"],
|
|
|
|
|
len(item["markdown"]),
|
|
|
|
|
item["embedding"],
|
|
|
|
|
)
|
|
|
|
|
for item in should_index
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(should_persist) > 0:
|
|
|
|
|
await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
|
|
|
|
|
[(self.page_id, sha1) for sha1 in should_persist])
|
|
|
|
|
|
|
|
|
|
await self.dbi.executemany(
|
|
|
|
|
"UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2"
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
[(page_id, sha1) for sha1 in should_persist],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if need_create_index:
|
|
|
|
|
await self.create_embedding_index()
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Add temporary document to the index
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
|
|
|
|
|
indexed_sha1_list = []
|
|
|
|
|
indexed_temp_sha1_list = []
|
|
|
|
|
doc_sha1_list = []
|
|
|
|
|
|
|
|
|
|
sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
|
|
|
|
|
self.table_name)
|
|
|
|
|
ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
|
|
|
|
|
sql = (
|
|
|
|
|
"SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)"
|
|
|
|
|
% (self.table_name)
|
|
|
|
|
)
|
|
|
|
|
ret = await self.dbi.fetch(sql, temp_doc_session_id)
|
|
|
|
|
for row in ret:
|
|
|
|
|
indexed_sha1_list.append(row[0])
|
|
|
|
|
if row[1]:
|
|
|
|
@ -269,41 +294,79 @@ class PageIndexHelper:
|
|
|
|
|
should_remove.append(sha1)
|
|
|
|
|
|
|
|
|
|
if len(should_index) > 0:
|
|
|
|
|
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
|
|
|
|
|
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index])
|
|
|
|
|
await self.dbi.executemany(
|
|
|
|
|
"""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);"""
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
[
|
|
|
|
|
(
|
|
|
|
|
item["sha1"],
|
|
|
|
|
item["page_id"],
|
|
|
|
|
item["text"],
|
|
|
|
|
len(item["text"]),
|
|
|
|
|
item["markdown"],
|
|
|
|
|
len(item["markdown"]),
|
|
|
|
|
item["embedding"],
|
|
|
|
|
temp_doc_session_id,
|
|
|
|
|
)
|
|
|
|
|
for item in should_index
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(should_remove) > 0:
|
|
|
|
|
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
|
|
|
|
|
self.page_id, temp_doc_session_id, should_remove)
|
|
|
|
|
await self.dbi.execute(
|
|
|
|
|
"DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)"
|
|
|
|
|
% (self.table_name),
|
|
|
|
|
temp_doc_session_id,
|
|
|
|
|
should_remove,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Search for text by consine similary
|
|
|
|
|
"""
|
|
|
|
|
async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
|
|
|
|
|
|
|
|
|
|
async def search_text_embedding(
|
|
|
|
|
self,
|
|
|
|
|
embedding: np.ndarray,
|
|
|
|
|
in_collection: bool = False,
|
|
|
|
|
limit: int = 10,
|
|
|
|
|
page_id: Optional[int] = None,
|
|
|
|
|
):
|
|
|
|
|
if in_collection:
|
|
|
|
|
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
|
return await self.dbi.fetch(
|
|
|
|
|
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
|
FROM %s
|
|
|
|
|
ORDER BY distance ASC
|
|
|
|
|
LIMIT %d""" % (self.table_name, limit), embedding)
|
|
|
|
|
LIMIT %d"""
|
|
|
|
|
% (self.table_name, limit),
|
|
|
|
|
embedding,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
|
return await self.dbi.fetch(
|
|
|
|
|
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
|
|
|
|
|
FROM %s
|
|
|
|
|
WHERE page_id = $2
|
|
|
|
|
ORDER BY distance ASC
|
|
|
|
|
LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)
|
|
|
|
|
LIMIT %d"""
|
|
|
|
|
% (self.table_name, limit),
|
|
|
|
|
embedding,
|
|
|
|
|
page_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Clear temporary index
|
|
|
|
|
"""
|
|
|
|
|
async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None):
|
|
|
|
|
|
|
|
|
|
async def clear_temp(
|
|
|
|
|
self, in_collection: bool = False, temp_doc_session_id: int = None, page_id: Optional[int] = None
|
|
|
|
|
):
|
|
|
|
|
sql = "DELETE FROM %s" % (self.table_name)
|
|
|
|
|
|
|
|
|
|
where = []
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
if not in_collection:
|
|
|
|
|
params.append(self.page_id)
|
|
|
|
|
if not in_collection and page_id:
|
|
|
|
|
params.append(page_id)
|
|
|
|
|
where.append("page_id = $%d" % len(params))
|
|
|
|
|
|
|
|
|
|
if temp_doc_session_id:
|
|
|
|
@ -316,3 +379,8 @@ class PageIndexHelper:
|
|
|
|
|
sql += " WHERE " + (" AND ".join(where))
|
|
|
|
|
|
|
|
|
|
await self.dbi.execute(sql, *params)
|
|
|
|
|
|
|
|
|
|
async def remove_by_page_id(self, page_id: int):
|
|
|
|
|
stmt = delete(self.orm).where(self.orm.page_id == page_id)
|
|
|
|
|
await self.session.execute(stmt)
|
|
|
|
|
await self.session.commit()
|
|
|
|
|