You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

163 lines
5.5 KiB
Python

import hashlib
import asyncpg
import numpy as np
from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector
import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped
from sqlalchemy.ext.asyncio import AsyncEngine
import config
from api.model.base import BaseModel
from service.database import DatabaseService
class TitleIndexModel(BaseModel):
    """ORM mapping of the ``embedding_search_title_index`` table.

    One row per indexed title: the title text, its SHA-1 key, the page,
    revision and collection it belongs to, and its embedding vector.
    ``TitleIndexHelper`` in this module queries the same table with raw SQL
    rather than through this mapping.
    """
    __tablename__ = "embedding_search_title_index"
    id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
    # Hex SHA-1 digest of the title (TitleIndexHelper computes it from the
    # UTF-8 encoded title). Indexed for lookups but NOT declared unique.
    sha1: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
    page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
    # pgvector column; its dimension is fixed by config.EMBEDDING_VECTOR_SIZE.
    embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
    # IVFFlat approximate-nearest-neighbour index with the cosine operator
    # class. Per the pgvector docs, it can only accelerate queries ordered by
    # the cosine-distance operator ``<=>``, not ``<->`` (L2) or ``<#>``.
    # NOTE(review): this passes the class-body mapped column to Index();
    # confirm the index is really emitted in the generated DDL.
    embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
                                       postgresql_using='ivfflat',
                                       postgresql_ops={'embedding': 'vector_cosine_ops'})
class TitleIndexHelper:
    """Async data-access helper for the ``embedding_search_title_index`` table.

    Use it as an async context manager::

        async with TitleIndexHelper(dbs) as helper:
            await helper.add(...)

    Entering acquires a connection from ``dbs.pool`` and calls
    ``register_vector`` on it so vector values can be exchanged with
    pgvector. Exiting releases that connection. The query methods run on
    ``self.dbi`` and must only be called while the connection is held.
    Nested ``async with`` on one instance is not supported: the inner exit
    releases the shared connection.

    Titles are keyed by the hex SHA-1 digest of their UTF-8 encoding
    (column ``sha1``).
    """

    __tablename__ = "embedding_search_title_index"

    # Column order used when building SELECT lists.
    columns = [
        "id",
        "sha1",
        "title",
        "page_id",
        "collection_id",
        "rev_id",
        "embedding",
    ]

    def __init__(self, dbs: "DatabaseService"):
        self.dbs = dbs
        self.initialized = False

    async def __aenter__(self):
        """Acquire a pooled connection and register the pgvector codec on it."""
        if not self.initialized:
            # pool.acquire() returns an async context manager; keep it so
            # __aexit__ releases exactly this connection.
            self.dbpool = self.dbs.pool.acquire()
            self.dbi = await self.dbpool.__aenter__()
            try:
                await register_vector(self.dbi)
            except BaseException as exc:
                # __aexit__ is not invoked when __aenter__ raises, so release
                # the connection here instead of leaking it.
                await self.dbpool.__aexit__(type(exc), exc, exc.__traceback__)
                raise
            self.initialized = True
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Release the connection acquired by __aenter__ (no-op if none is held)."""
        if not self.initialized:
            return
        # Reset first so a later __aenter__ acquires a fresh connection
        # instead of reusing the released one.
        self.initialized = False
        await self.dbpool.__aexit__(exc_type, exc, tb)

    def get_columns(self, exclude=None):
        """Return the table's columns as a comma-separated SELECT list.

        :param exclude: optional collection of column names to leave out.
        """
        excluded = exclude or ()
        return ", ".join(col for col in self.columns if col not in excluded)

    def _select_columns(self, with_embedding):
        """SELECT list for lookups; the embedding column is only included on request."""
        if with_embedding:
            return self.get_columns()
        return self.get_columns(exclude=["embedding"])

    @staticmethod
    def _title_sha1(title):
        """Hex SHA-1 digest of ``title`` encoded as UTF-8 (the row key)."""
        return hashlib.sha1(title.encode("utf-8")).hexdigest()

    async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray):
        """Add a title to the index unless a row with the same title exists.

        :returns: the new row's ``id``, or ``False`` if the title was
            already indexed.
        """
        title_sha1 = self._title_sha1(title)
        # NOTE: check-then-insert is not atomic, and ``sha1`` has no unique
        # constraint, so concurrent adds of one title may both insert.
        exists = await self.dbi.fetchval(
            "SELECT 1 FROM embedding_search_title_index WHERE sha1 = $1 LIMIT 1",
            title_sha1,
        )
        if exists is not None:
            return False
        return await self.dbi.fetchval(
            """INSERT INTO embedding_search_title_index
            (sha1, title, page_id, rev_id, collection_id, embedding)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id""",
            title_sha1, title, page_id, rev_id, collection_id, embedding,
            column=0,
        )

    async def remove(self, title: str):
        """Delete every index row whose title is ``title``."""
        await self.dbi.execute(
            "DELETE FROM embedding_search_title_index WHERE sha1 = $1",
            self._title_sha1(title),
        )

    async def update_rev_id(self, page_id: int, rev_id: int):
        """Set the indexed revision id of every row belonging to ``page_id``."""
        await self.dbi.execute(
            "UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2",
            rev_id, page_id,
        )

    async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray):
        """Overwrite title, revision, collection and embedding for ``page_id``.

        The ``sha1`` key is recomputed from the new title so that title
        lookups keep finding the row after a rename. A ``None``
        ``collection_id`` defaults to ``page_id``.
        """
        if collection_id is None:
            collection_id = page_id
        await self.dbi.execute(
            """UPDATE embedding_search_title_index
            SET sha1 = $1, title = $2, rev_id = $3, collection_id = $4, embedding = $5
            WHERE page_id = $6""",
            self._title_sha1(title), title, rev_id, collection_id, embedding, page_id,
        )

    async def search_title_embedding(self, embedding: np.ndarray, limit: int = 10):
        """Return up to ``limit`` titles closest to ``embedding`` by cosine distance.

        Each record holds every column except ``embedding``, plus ``distance``
        (pgvector ``<=>``: smaller means more similar). Results are ordered
        closest first. Ordering by ``<=>`` matches the index's
        ``vector_cosine_ops`` operator class.
        """
        query = (
            "SELECT %s, embedding <=> $1 AS distance "
            "FROM embedding_search_title_index "
            "ORDER BY embedding <=> $1 "
            "LIMIT $2"
        ) % self._select_columns(with_embedding=False)
        return await self.dbi.fetch(query, embedding, int(limit))

    async def find_by_title(self, title: str, with_embedding=False):
        """Return the first row whose title is ``title``, or ``None``."""
        return await self.dbi.fetchrow(
            "SELECT %s FROM embedding_search_title_index WHERE sha1 = $1"
            % self._select_columns(with_embedding),
            self._title_sha1(title),
        )

    async def find_by_page_id(self, page_id: int, with_embedding=False):
        """Return the first row for ``page_id``, or ``None``."""
        return await self.dbi.fetchrow(
            "SELECT %s FROM embedding_search_title_index WHERE page_id = $1"
            % self._select_columns(with_embedding),
            page_id,
        )

    async def find_by_collection_id(self, collection_id: int, with_embedding=False):
        """Return all rows belonging to ``collection_id``."""
        return await self.dbi.fetch(
            "SELECT %s FROM embedding_search_title_index WHERE collection_id = $1"
            % self._select_columns(with_embedding),
            collection_id,
        )