You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

319 lines
12 KiB
Python

from __future__ import annotations
import hashlib
from typing import Optional, Type
import asyncpg
from api.model.base import BaseModel
import config
import numpy as np
import sqlalchemy
from sqlalchemy import Index, select, update, delete, Select
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.ext.asyncio import AsyncSession
from pgvector.asyncpg import register_vector
from pgvector.sqlalchemy import Vector
from service.database import DatabaseService
# Cache of generated page-index ORM model classes, keyed by collection id;
# create_page_index_model() fills it so repeated calls for the same
# collection return the same class.
page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
class AbstractPageIndexModel(BaseModel):
    """Column layout shared by every per-collection page-index table.

    Abstract (``__abstract__ = True``), so it maps no table itself; concrete
    per-collection subclasses are generated by ``create_page_index_model``.
    """
    __abstract__ = True
    # Surrogate auto-increment primary key.
    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Page the chunk belongs to (indexed; queries are usually scoped by it).
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # 40-char SHA-1 hex digest of `text`; used to detect already-indexed chunks.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
    # pgvector embedding whose dimension is config.EMBEDDING_VECTOR_SIZE.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    # Plain-text chunk and its length.
    text: Mapped[str] = mapped_column(sqlalchemy.Text)
    text_len: Mapped[int] = mapped_column(sqlalchemy.Integer)
    # Optional markdown form of the chunk and its length (both nullable).
    markdown: Mapped[Optional[str]] = mapped_column(sqlalchemy.Text, nullable=True)
    markdown_len: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
    # NULL for persisted rows; holds the temporary-document session id for
    # rows written as temporary entries.
    # NOTE(review): not indexed, although temporary-row cleanup filters on it;
    # consider index=True.
    temp_doc_session_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
def create_page_index_model(collection_id: int):
    """Return the ORM model class for one collection's page-index table.

    The class maps ``embedding_search_page_index_<collection_id>`` and is
    created on first request, then memoised in ``page_index_model_list`` so
    later calls for the same collection get the identical class object.
    """
    cached_model = page_index_model_list.get(collection_id)
    if cached_model is not None:
        return cached_model

    class _PageIndexModel(AbstractPageIndexModel):
        __tablename__ = "embedding_search_page_index_%s" % str(collection_id)
        # NOTE(review): this Index is a plain class attribute built on the
        # abstract base's column rather than an entry in __table_args__;
        # confirm it is really attached to (and created with) this table.
        embedding_index = sqlalchemy.Index(
            __tablename__ + "_embedding_idx",
            AbstractPageIndexModel.embedding,
            postgresql_using='ivfflat',
            postgresql_ops={'embedding': 'vector_cosine_ops'},
        )

    page_index_model_list[collection_id] = _PageIndexModel
    return _PageIndexModel
class PageIndexHelper:
    """Read/write helper for one collection's page-index table.

    The table is ``embedding_search_page_index_<collection_id>``. Queries are
    scoped to ``page_id`` unless an ``in_collection`` flag widens them to the
    whole table. Use the helper as an async context manager: entering acquires
    a raw connection from ``dbs.pool`` (with the pgvector codec registered)
    and opens an ORM session from ``dbs.create_session``; exiting releases
    both.

    Rows whose ``temp_doc_session_id`` is NULL are persisted entries; rows
    with it set are temporary entries written by ``index_temp_doc``.
    """

    # Column names of the page-index table.
    columns = [
        "id",
        "page_id",
        "sha1",
        "text",
        "text_len",
        "markdown",
        "markdown_len",
        "embedding",
        "temp_doc_session_id",
    ]

    def __init__(self, dbs: "DatabaseService", collection_id: int, page_id: Optional[int]):
        """Store configuration only; no database work happens until entered.

        Args:
            dbs: database service providing ``pool``, ``engine`` and
                ``create_session``.
            collection_id: collection whose table this helper works on.
            page_id: page the helper is scoped to; ``None`` is stored as -1.
        """
        self.dbs = dbs
        self.collection_id = collection_id
        self.page_id = page_id if page_id is not None else -1
        self.table_name = "embedding_search_page_index_%s" % str(collection_id)
        self.initialized = False
        self.table_initialized = False

    async def __aenter__(self):
        """Acquire a connection and an ORM session; returns ``self``.

        Entering an already-entered helper is a no-op that still returns the
        helper, so ``async with helper as h`` never binds ``None``.
        """
        if self.initialized:
            return self
        self.dbpool = self.dbs.pool.acquire()
        self.dbi = await self.dbpool.__aenter__()
        try:
            await register_vector(self.dbi)
            self.orm = create_page_index_model(self.collection_id)
            self.create_session = self.dbs.create_session
            self.session = self.dbs.create_session()
            await self.session.__aenter__()
        except BaseException as err:
            # Setup failed part-way: give the pooled connection back instead
            # of leaking it, then propagate the original error.
            await self.dbpool.__aexit__(type(err), err, err.__traceback__)
            raise
        self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the pooled connection and close the session.

        Idempotent: exiting a helper that is not entered does nothing, and
        the session is closed even if releasing the connection raises.
        Exceptions from the ``async with`` body are never suppressed.
        """
        if not self.initialized:
            return
        # Reset first so a later __aenter__ re-acquires fresh resources
        # instead of reusing the released connection.
        self.initialized = False
        try:
            await self.dbpool.__aexit__(exc_type, exc, tb)
        finally:
            await self.session.__aexit__(exc_type, exc, tb)

    async def table_exists(self):
        """Return whether this collection's table exists in the ``public`` schema."""
        exists = await self.dbi.fetchval("""SELECT EXISTS (
            SELECT 1
            FROM information_schema.tables
            WHERE table_schema = 'public'
            AND table_name = $1
        );""", self.table_name, column=0)
        return bool(exists)

    async def init_table(self):
        """Create the collection's table from the ORM model if it is missing.

        The schema (columns and per-column indexes) comes from the generated
        model's ``__table__``. The result is remembered per helper instance.
        """
        if self.table_initialized:
            return
        if not await self.table_exists():
            table = self.orm.__table__
            async with self.dbs.engine.begin() as conn:
                # checkfirst narrows the race with another worker creating
                # the same table between table_exists() and this call.
                await conn.run_sync(lambda sync_conn: table.create(sync_conn, checkfirst=True))
        self.table_initialized = True

    async def create_embedding_index(self):
        """Hook for building the vector index; currently a no-op.

        NOTE(review): an earlier version ran
        ``CREATE INDEX IF NOT EXISTS <table>_embedding_idx ON <table>
        USING ivfflat (embedding vector_cosine_ops)`` here; it is disabled.
        """
        pass

    def sha1_doc(self, doc: list):
        """Fill ``item["sha1"]`` for items that lack a non-empty one.

        The value is the SHA-1 hex digest of the UTF-8 encoded ``item["text"]``.
        Mutates the dicts in *doc* in place.
        """
        for item in doc:
            if not item.get("sha1"):
                item["sha1"] = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()

    def _insert_values(self, item: dict) -> tuple:
        """Build (sha1, page_id, text, text_len, markdown, markdown_len, embedding) for *item*.

        ``markdown`` is optional; a missing or ``None`` markdown stores NULL
        for both markdown columns, matching their nullable schema.
        """
        text = item["text"]
        markdown = item.get("markdown")
        markdown_len = len(markdown) if markdown is not None else None
        return (item["sha1"], self.page_id, text, len(text), markdown, markdown_len, item["embedding"])

    async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False) -> list[str]:
        """Return the sha1 values already stored in the table.

        Args:
            with_temporary: include temporary rows (``temp_doc_session_id``
                set); when False only persisted rows are returned.
            in_collection: search the whole table instead of only this page.
        """
        # Select only the sha1 column so embeddings and text are not loaded.
        stmt = select(self.orm.sha1)
        if not with_temporary:
            stmt = stmt.where(self.orm.temp_doc_session_id.is_(None))
        if not in_collection:
            stmt = stmt.where(self.orm.page_id == self.page_id)
        result = await self.session.scalars(stmt)
        return list(result)

    async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
        """Return the items of *doc* whose sha1 is not yet stored for this page.

        Fills missing ``sha1`` keys on *doc* as a side effect.
        """
        indexed = set(await self.get_indexed_sha1(with_temporary))
        self.sha1_doc(doc)
        return [item for item in doc if item["sha1"] not in indexed]

    async def remove_outdated_doc(self, doc: list):
        """Drop this page's temporary rows and persisted rows not present in *doc*."""
        await self.clear_temp()
        indexed_sha1_list = await self.get_indexed_sha1(False)
        self.sha1_doc(doc)
        doc_sha1 = {item["sha1"] for item in doc}
        should_remove = [sha1 for sha1 in indexed_sha1_list if sha1 not in doc_sha1]
        if should_remove:
            await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
                                   self.page_id, should_remove)

    async def index_doc(self, doc: list):
        """Persist *doc* for this page.

        Items already stored as temporary rows are promoted to persisted rows
        (``temp_doc_session_id`` cleared); items not stored at all are inserted
        as persisted rows. Each item needs ``text`` and ``embedding``;
        ``markdown`` is optional. Stale persisted rows are not removed here;
        ``remove_outdated_doc`` does that.
        """
        indexed_persist_sha1 = set()
        indexed_temp_sha1 = set()
        ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
                                   self.page_id)
        for row in ret:
            # Compare with None: a session id of 0 still marks a temporary row.
            if row[1] is not None:
                indexed_temp_sha1.add(row[0])
            else:
                indexed_persist_sha1.add(row[0])

        # Create the vector index when the page had no persisted rows yet.
        need_create_index = not indexed_persist_sha1

        self.sha1_doc(doc)
        should_index = []
        should_persist = []
        for item in doc:
            sha1 = item["sha1"]
            if sha1 in indexed_temp_sha1:
                should_persist.append(sha1)
            elif sha1 not in indexed_persist_sha1:
                should_index.append(item)

        if should_index:
            await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
                [self._insert_values(item) for item in should_index])

        if should_persist:
            await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
                                       [(self.page_id, sha1) for sha1 in should_persist])

        if need_create_index:
            await self.create_embedding_index()

    async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
        """Sync this page's temporary rows for one session with *doc*.

        Items not already stored (as persisted rows or as this session's
        temporary rows) are inserted tagged with *temp_doc_session_id*; this
        session's temporary rows whose sha1 is no longer in *doc* are deleted.
        """
        sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
            self.table_name)
        ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
        indexed_sha1 = set()
        indexed_temp_sha1 = []
        for row in ret:
            indexed_sha1.add(row[0])
            # Compare with None: a session id of 0 is still a temporary row.
            if row[1] is not None:
                indexed_temp_sha1.append(row[0])

        self.sha1_doc(doc)
        doc_sha1 = {item["sha1"] for item in doc}
        should_index = [item for item in doc if item["sha1"] not in indexed_sha1]
        should_remove = [sha1 for sha1 in indexed_temp_sha1 if sha1 not in doc_sha1]

        if should_index:
            await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
                [self._insert_values(item) + (temp_doc_session_id,) for item in should_index])

        if should_remove:
            await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
                                   self.page_id, temp_doc_session_id, should_remove)

    async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
        """Return the *limit* rows nearest to *embedding*, closest first.

        Each record has id, sha1, text, text_len, markdown, markdown_len and
        ``distance``. Searches this page only unless *in_collection* is True.

        NOTE(review): ``<->`` is pgvector's L2 distance, while the intended
        metric (and the ivfflat ``vector_cosine_ops`` index) is cosine
        (``<=>``). Ranking agrees only for normalized embeddings; switching
        operators would change the returned distance values, so confirm
        callers' thresholds first.
        """
        if in_collection:
            return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
                FROM %s
                ORDER BY distance ASC
                LIMIT %d""" % (self.table_name, limit), embedding)
        return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
            FROM %s
            WHERE page_id = $2
            ORDER BY distance ASC
            LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)

    async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: Optional[int] = None):
        """Delete temporary rows.

        Args:
            in_collection: clear the whole table instead of only this page.
            temp_doc_session_id: restrict deletion to this session's rows;
                ``None`` deletes every temporary row in scope.
        """
        where = []
        params = []
        if not in_collection:
            params.append(self.page_id)
            where.append("page_id = $%d" % len(params))
        # Compare with None so a session id of 0 is honoured.
        if temp_doc_session_id is not None:
            params.append(temp_doc_session_id)
            where.append("temp_doc_session_id = $%d" % len(params))
        else:
            where.append("temp_doc_session_id IS NOT NULL")
        sql = "DELETE FROM %s WHERE %s" % (self.table_name, " AND ".join(where))
        await self.dbi.execute(sql, *params)