|
|
|
@ -1,14 +1,15 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
import hashlib
|
|
|
|
|
import asyncpg
|
|
|
|
|
import numpy as np
|
|
|
|
|
from pgvector.sqlalchemy import Vector
|
|
|
|
|
from pgvector.asyncpg import register_vector
|
|
|
|
|
import sqlalchemy
|
|
|
|
|
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
|
|
|
|
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
|
|
|
|
|
|
|
|
import config
|
|
|
|
|
from api.model.base import BaseModel
|
|
|
|
|
from api.model.base import BaseHelper, BaseModel
|
|
|
|
|
from service.database import DatabaseService
|
|
|
|
|
|
|
|
|
|
class TitleIndexModel(BaseModel):
|
|
|
|
@ -19,14 +20,15 @@ class TitleIndexModel(BaseModel):
|
|
|
|
|
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
|
|
|
|
|
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
|
|
|
|
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
|
|
|
|
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
|
|
|
|
|
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
|
|
|
|
|
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
|
|
|
|
|
latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
|
|
|
|
|
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)))
|
|
|
|
|
|
|
|
|
|
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
|
|
|
|
|
postgresql_using='ivfflat',
|
|
|
|
|
postgresql_ops={'embedding': 'vector_cosine_ops'})
|
|
|
|
|
|
|
|
|
|
class TitleIndexHelper:
|
|
|
|
|
class TitleIndexHelper(BaseHelper):
|
|
|
|
|
__tablename__ = "embedding_search_title_index"
|
|
|
|
|
|
|
|
|
|
columns = [
|
|
|
|
@ -36,12 +38,12 @@ class TitleIndexHelper:
|
|
|
|
|
"page_id",
|
|
|
|
|
"collection_id",
|
|
|
|
|
"rev_id",
|
|
|
|
|
"latest_rev_id"
|
|
|
|
|
"embedding",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def __init__(self, dbs: DatabaseService):
|
|
|
|
|
self.dbs = dbs
|
|
|
|
|
self.initialized = False
|
|
|
|
|
super().__init__(dbs)
|
|
|
|
|
|
|
|
|
|
async def __aenter__(self):
|
|
|
|
|
if not self.initialized:
|
|
|
|
@ -50,12 +52,11 @@ class TitleIndexHelper:
|
|
|
|
|
|
|
|
|
|
await register_vector(self.dbi)
|
|
|
|
|
|
|
|
|
|
self.initialized = True
|
|
|
|
|
|
|
|
|
|
return self
|
|
|
|
|
await super().__aenter__()
|
|
|
|
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
|
|
|
await self.dbpool.__aexit__(exc_type, exc, tb)
|
|
|
|
|
await super().__aexit__(exc_type, exc, tb)
|
|
|
|
|
|
|
|
|
|
def get_columns(self, exclude=[]):
|
|
|
|
|
if len(exclude) == 0:
|
|
|
|
@ -66,17 +67,14 @@ class TitleIndexHelper:
|
|
|
|
|
"""
|
|
|
|
|
Add a title to the index
|
|
|
|
|
"""
|
|
|
|
|
async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray):
|
|
|
|
|
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
|
|
|
|
ret = await self.dbi.fetchrow("SELECT * FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
|
|
|
|
|
|
|
|
|
|
if ret is None:
|
|
|
|
|
new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index
|
|
|
|
|
(sha1, title, page_id, rev_id, collection_id, embedding)
|
|
|
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
|
|
|
RETURNING id""",
|
|
|
|
|
title_sha1, title, page_id, rev_id, collection_id, embedding, column=0)
|
|
|
|
|
return new_id
|
|
|
|
|
async def add(self, obj: TitleIndexModel):
|
|
|
|
|
obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
|
|
|
|
|
|
|
|
|
|
if await self.find_by_sha1(obj.sha1) is None:
|
|
|
|
|
self.session.add(obj)
|
|
|
|
|
await self.session.commit()
|
|
|
|
|
await self.session.refresh(obj)
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
@ -85,25 +83,15 @@ class TitleIndexHelper:
|
|
|
|
|
"""
|
|
|
|
|
async def remove(self, title: str):
|
|
|
|
|
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
|
|
|
|
await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
|
|
|
|
|
stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1)
|
|
|
|
|
await self.session.execute(stmt)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Update the indexed revision id of a title
|
|
|
|
|
"""
|
|
|
|
|
async def update_rev_id(self, page_id: int, rev_id: int):
|
|
|
|
|
await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2", rev_id, page_id)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Update title data
|
|
|
|
|
"""
|
|
|
|
|
async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray):
|
|
|
|
|
if collection_page_id is None:
|
|
|
|
|
collection_page_id = page_id
|
|
|
|
|
|
|
|
|
|
await self.dbi.execute("""UPDATE embedding_search_title_index
|
|
|
|
|
SET title = $1, rev_id = $2, collection_id = $3, embedding = $4
|
|
|
|
|
WHERE page_id = $5""",
|
|
|
|
|
title, rev_id, collection_id, embedding, page_id)
|
|
|
|
|
async def update(self, obj: TitleIndexModel):
|
|
|
|
|
obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
|
|
|
|
|
await self.session.merge(obj)
|
|
|
|
|
await self.session.commit()
|
|
|
|
|
await self.session.refresh(obj)
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Search for titles by consine similary
|
|
|
|
@ -120,43 +108,18 @@ class TitleIndexHelper:
|
|
|
|
|
"""
|
|
|
|
|
Find a title in the index
|
|
|
|
|
"""
|
|
|
|
|
async def find_by_title(self, title: str, with_embedding=False):
|
|
|
|
|
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
|
|
|
|
|
|
|
|
|
if with_embedding:
|
|
|
|
|
columns = self.get_columns()
|
|
|
|
|
else:
|
|
|
|
|
columns = self.get_columns(exclude=["embedding"])
|
|
|
|
|
|
|
|
|
|
ret = await self.dbi.fetchrow(
|
|
|
|
|
"SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns,
|
|
|
|
|
title_sha1
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
async def find_by_page_id(self, page_id: int, with_embedding=False):
|
|
|
|
|
if with_embedding:
|
|
|
|
|
columns = self.get_columns()
|
|
|
|
|
else:
|
|
|
|
|
columns = self.get_columns(exclude=["embedding"])
|
|
|
|
|
|
|
|
|
|
ret = await self.dbi.fetchrow(
|
|
|
|
|
"SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns,
|
|
|
|
|
page_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
async def find_by_collection_id(self, collection_id: int, with_embedding=False):
|
|
|
|
|
if with_embedding:
|
|
|
|
|
columns = self.get_columns()
|
|
|
|
|
else:
|
|
|
|
|
columns = self.get_columns(exclude=["embedding"])
|
|
|
|
|
|
|
|
|
|
ret = await self.dbi.fetch(
|
|
|
|
|
"SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns,
|
|
|
|
|
collection_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
async def find_by_title(self, title: str):
|
|
|
|
|
sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
|
|
|
|
|
return await self.find_by_sha1(sha1)
|
|
|
|
|
|
|
|
|
|
async def find_by_sha1(self, sha1: str):
|
|
|
|
|
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1)
|
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
async def find_by_page_id(self, page_id: int):
|
|
|
|
|
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
|
|
|
|
|
return await self.session.scalar(stmt)
|
|
|
|
|
|
|
|
|
|
async def find_list_by_collection_id(self, collection_id: int):
|
|
|
|
|
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
|
|
|
|
|
return await self.session.scalars(stmt)
|