You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
11 KiB
Python
307 lines
11 KiB
Python
from __future__ import annotations
|
|
from typing import Optional, TypedDict
|
|
|
|
import sqlalchemy
|
|
from api.model.embedding_search.title_collection import (
|
|
TitleCollectionHelper,
|
|
TitleCollectionModel,
|
|
)
|
|
from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
|
|
from api.model.embedding_search.page_index import PageIndexHelper
|
|
from service.database import DatabaseService
|
|
from service.mediawiki_api import MediaWikiApi
|
|
from service.openai_api import OpenAIApi
|
|
from service.tiktoken import TikTokenService
|
|
from utils.wiki import getWikiSentences
|
|
|
|
|
|
class EmbeddingRunningException(Exception):
    """Raised when an embedding-index job for this page is already in progress."""

    pass
|
|
|
|
|
|
class EmbeddingSearchArgs(TypedDict):
    """Keyword arguments accepted by EmbeddingSearchService.search()."""

    # Maximum number of results to return.
    limit: Optional[int]
    # Search the whole collection instead of only this page.
    in_collection: Optional[bool]
    # Results with distance >= this value are filtered out.
    distance_limit: Optional[float]
|
|
|
|
|
|
class EmbeddingSearchService:
    """Builds and queries text-embedding indexes for a MediaWiki page."""

    # Class-level (shared across instances) registry of page ids currently
    # being indexed; used to reject concurrent index runs for the same page.
    indexing_page_ids: list[int] = []
|
|
|
|
    def __init__(self, dbs: DatabaseService, title: str):
        """Prepare the service for *title*; real setup happens in __aenter__.

        Args:
            dbs: Database service shared with all model helpers.
            title: Full page title; the part before the first "/" is the
                collection (base) title.
        """
        self.dbs = dbs

        self.title = title
        # Sub-pages ("Base/Sub") belong to the collection of their base page.
        self.base_title = title.split("/")[0]

        self.title_index_helper = TitleIndexHelper(dbs)
        self.title_collection_helper = TitleCollectionHelper(dbs)
        # Created in __aenter__ / update_title_index once the collection is known.
        self.page_index: Optional[PageIndexHelper] = None

        # Created asynchronously in __aenter__.
        self.tiktoken: Optional[TikTokenService] = None

        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        self.page_id: Optional[int] = None
        self.collection_id: Optional[int] = None

        self.title_index: Optional[TitleIndexModel] = None
        self.collection_info: Optional[TitleCollectionModel] = None

        # Raw page info as returned by the MediaWiki API (see load_page_info).
        self.page_info: Optional[dict] = None
        # Docs pending indexing; filled by prepare_update_index.
        self.unindexed_docs: Optional[list] = None
|
|
|
|
    async def __aenter__(self):
        """Open helper resources and resolve page/collection ids.

        Looks the page up by title first; if the title is unknown (the page
        may have been renamed), falls back to a lookup by the page id
        reported by the MediaWiki API.
        """
        self.tiktoken = await TikTokenService.create()

        await self.title_index_helper.__aenter__()
        await self.title_collection_helper.__aenter__()

        self.title_index = await self.title_index_helper.find_by_title(self.title)
        if self.title_index is None:
            # Title may have changed; resolve the record via page id instead.
            await self.load_page_info()
            self.title_index = await self.title_index_helper.find_by_page_id(
                self.page_info["pageid"]
            )
            self.page_id = self.page_info["pageid"]
        else:
            self.page_id = self.title_index.page_id
        # NOTE(review): if the page has never been indexed, title_index is
        # still None here and the attribute access below raises
        # AttributeError — confirm callers always handle/avoid that case.
        self.collection_id = self.title_index.collection_id
        self.page_index = PageIndexHelper(self.dbs, self.collection_id)
        await self.page_index.__aenter__()

        return self
|
|
|
|
async def __aexit__(self, exc_type, exc, tb):
|
|
await self.title_index_helper.__aexit__(exc_type, exc, tb)
|
|
await self.title_collection_helper.__aexit__(exc_type, exc, tb)
|
|
|
|
if self.page_index is not None:
|
|
await self.page_index.__aexit__(exc_type, exc, tb)
|
|
|
|
async def page_index_exists(self, check_table=True):
|
|
if check_table:
|
|
return self.page_index and await self.page_index.table_exists()
|
|
else:
|
|
return self.page_index is not None
|
|
|
|
async def load_page_info(self, reload=False):
|
|
if self.page_info is None or reload:
|
|
self.page_info = await self.mwapi.get_page_info(self.title)
|
|
|
|
async def should_update_page_index(self, remote_update=False):
|
|
if not remote_update:
|
|
if self.title_index is None:
|
|
return True
|
|
return self.title_index.indexed_rev_id != self.title_index.latest_rev_id
|
|
else:
|
|
await self.load_page_info()
|
|
|
|
if (
|
|
self.title_index is not None
|
|
and await self.page_index_exists()
|
|
and self.title_index.indexed_rev_id == self.page_info["lastrevid"]
|
|
):
|
|
# Not changed
|
|
return False
|
|
|
|
return True
|
|
|
|
async def update_title_index(self, remote_update=False):
|
|
if not await self.should_update_page_index(remote_update):
|
|
return False
|
|
|
|
await self.load_page_info()
|
|
|
|
self.page_id = self.page_info["pageid"]
|
|
|
|
if self.page_id in self.indexing_page_ids:
|
|
raise EmbeddingRunningException("Page index is running now")
|
|
|
|
self.title: str = self.page_info["title"]
|
|
self.base_title = self.title.split("/")[0]
|
|
|
|
# Find collection by base title
|
|
self.collection_info = await self.title_collection_helper.find_by_title(
|
|
self.base_title
|
|
)
|
|
if self.collection_info is None:
|
|
# Create collection
|
|
self.collection_info = await self.title_collection_helper.add(
|
|
self.base_title
|
|
)
|
|
if self.collection_info is None:
|
|
raise Exception("Failed to create title collection")
|
|
self.collection_id = self.collection_info.id
|
|
|
|
if self.title_index == None:
|
|
# Create title record
|
|
self.title_index = TitleIndexModel(
|
|
title=self.page_info["title"],
|
|
page_id=self.page_id,
|
|
indexed_rev_id=None,
|
|
latest_rev_id=self.page_info["lastrevid"],
|
|
collection_id=self.collection_id,
|
|
embedding=None,
|
|
)
|
|
self.title_index = await self.title_index_helper.add(self.title_index)
|
|
if self.title_index is None:
|
|
raise Exception("Failed to create title index")
|
|
else:
|
|
self.title_index.latest_rev_id = self.page_info["lastrevid"]
|
|
# Title changed, remove embedding
|
|
# Title sha1 will be updated by model helper
|
|
if self.title_index.title != self.page_info["title"]:
|
|
self.title_index.title = self.page_info["title"]
|
|
self.title_index.embedding = None
|
|
|
|
# Collection changed, remove old index
|
|
if self.collection_id != self.title_index.collection_id:
|
|
async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index:
|
|
await self.page_index.init_table()
|
|
old_page_index.remove_by_page_id(self.page_id)
|
|
self.title_index.collection_id = self.collection_id
|
|
|
|
await self.title_index_helper.update(self.title_index)
|
|
|
|
# Update collection main page id
|
|
if (
|
|
self.title == self.collection_info.title
|
|
and self.page_id != self.collection_info.page_id
|
|
):
|
|
await self.title_collection_helper.set_main_page_id(
|
|
self.base_title, self.page_id
|
|
)
|
|
|
|
if self.page_index:
|
|
await self.page_index.__aexit__(None, None, None)
|
|
self.page_index = PageIndexHelper(self.dbs, self.collection_id)
|
|
await self.page_index.__aenter__()
|
|
await self.page_index.init_table()
|
|
|
|
async def prepare_update_index(self):
|
|
await self.update_title_index()
|
|
|
|
page_content = await self.mwapi.parse_page(self.title)
|
|
|
|
self.sentences = getWikiSentences(page_content)
|
|
|
|
self.unindexed_docs = await self.page_index.get_unindexed_doc(
|
|
self.sentences, self.page_id, with_temporary=False
|
|
)
|
|
|
|
return True
|
|
|
|
async def get_unindexed_tokens(self):
|
|
if self.unindexed_docs is None:
|
|
return 0
|
|
else:
|
|
tokens = 0
|
|
for doc in self.unindexed_docs:
|
|
if "text" in doc:
|
|
tokens += await self.tiktoken.get_tokens(doc["text"])
|
|
|
|
return tokens
|
|
|
|
    async def update_page_index(self, on_progress=None):
        """Embed and index all pending docs, then refresh the title record.

        Must be called after prepare_update_index() has populated
        self.unindexed_docs and self.sentences.

        Args:
            on_progress: Optional async callback invoked as
                on_progress(done, total) while docs are embedded.

        Returns:
            Total OpenAI token usage (int), or False when
            prepare_update_index() was never run.
        """
        if self.unindexed_docs is None:
            return False

        # Max accumulated len(doc) per embedding batch.
        # NOTE(review): len(doc) on a dict counts keys, not text length —
        # confirm the intended unit against get_embeddings' batch limits.
        chunk_limit = 500

        chunk_len = 0
        doc_chunk = []
        total_token_usage = 0
        # Number of docs fully embedded in previous batches.
        processed_len = 0

        async def on_embedding_progress(current, length):
            # Translate per-batch progress into overall progress.
            nonlocal processed_len

            indexed_docs = processed_len + current

            if on_progress is not None:
                await on_progress(indexed_docs, len(self.unindexed_docs))

        async def embedding_doc(doc_chunk):
            # Embed one batch and write it into the page index; returns
            # the token usage reported by the OpenAI API.
            (doc_chunk, token_usage) = await self.openai_api.get_embeddings(
                doc_chunk, on_embedding_progress
            )
            await self.page_index.index_doc(doc_chunk, self.page_id)

            return token_usage

        if len(self.unindexed_docs) > 0:
            if on_progress is not None:
                await on_progress(0, len(self.unindexed_docs))

            for doc in self.unindexed_docs:
                chunk_len += len(doc)

                # Flush the current batch before it exceeds the limit.
                if chunk_len > chunk_limit:
                    total_token_usage += await embedding_doc(doc_chunk)
                    processed_len += len(doc_chunk)
                    if on_progress is not None:
                        await on_progress(processed_len, len(self.unindexed_docs))

                    doc_chunk = []
                    chunk_len = len(doc)

                doc_chunk.append(doc)

            # Flush the final partial batch.
            if len(doc_chunk) > 0:
                total_token_usage += await embedding_doc(doc_chunk)
                if on_progress is not None:
                    await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))

        # Drop index rows for sentences no longer present on the page.
        await self.page_index.remove_outdated_doc(self.sentences, self.page_id)

        # Update database
        # This task may take a long time, refresh model to retrieve latest data
        self.title_index = await self.title_index_helper.refresh(self.title_index)

        self.title_index.indexed_rev_id = self.page_info["lastrevid"]

        # Update title embedding (lazy-loaded attribute; only embed the
        # title if no embedding is stored yet).
        if await self.title_index.awaitable_attrs.embedding is None:
            doc_chunk = [{"text": self.title}]
            (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
            total_token_usage += token_usage

            embedding = doc_chunk[0]["embedding"]
            self.title_index.embedding = embedding

        await self.title_index_helper.update(self.title_index)

        return total_token_usage
|
|
|
|
async def search(
|
|
self,
|
|
query: str,
|
|
limit: int = 10,
|
|
in_collection: bool = False,
|
|
distance_limit: float = 0.6,
|
|
):
|
|
if self.page_index is None:
|
|
raise Exception("Page index is not initialized")
|
|
|
|
query_doc = [{"text": query}]
|
|
query_doc, token_usage = await self.openai_api.get_embeddings(query_doc)
|
|
query_embedding = query_doc[0]["embedding"]
|
|
|
|
if query_embedding is None:
|
|
return [], token_usage
|
|
|
|
res = await self.page_index.search_text_embedding(
|
|
query_embedding, in_collection, limit, self.page_id
|
|
)
|
|
if res:
|
|
filtered = []
|
|
for one in res:
|
|
if one["distance"] < distance_limit:
|
|
filtered.append(dict(one))
|
|
return filtered, token_usage
|
|
else:
|
|
return res, token_usage
|