You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
8.2 KiB
Python

from __future__ import annotations
from typing import Optional, TypedDict
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
from api.model.embedding_search.title_index import TitleIndexHelper
from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences
class EmbeddingSearchArgs(TypedDict):
    """Search options mirroring the keyword arguments of ``EmbeddingSearchService.search``.

    Each value may be ``None`` (``Optional``). The TypedDict is total, so a type
    checker expects all three keys to be present.
    """

    # Result-count limit forwarded to the page-index embedding search.
    limit: Optional[int]
    # Forwarded as ``in_collection`` to the page-index search
    # (presumably widens the search to the page's title collection — confirm).
    in_collection: Optional[bool]
    # Hits whose distance is not strictly below this value are discarded.
    distance_limit: Optional[float]
class EmbeddingSearchService:
    """Maintain and query the embedding search index of a single wiki page.

    Use as an async context manager: ``__aenter__`` opens the title-index and
    title-collection helpers (plus the page's own index when the title is already
    in the title index) and ``__aexit__`` closes them.

    ``update_page_index()`` only does work after ``prepare_update_index()`` has
    returned True; ``get_unindexed_tokens()`` can be used in between to estimate
    the embedding cost.
    """

    def __init__(self, dbs: DatabaseService, title: str):
        self.dbs = dbs
        self.title = title
        # Titles sharing the segment before the first "/" share one title collection.
        self.base_title = title.split("/")[0]

        self.title_index = TitleIndexHelper(dbs)
        self.title_collection = TitleCollectionHelper(dbs)
        self.page_index: Optional[PageIndexHelper] = None
        self.tiktoken: Optional[TikTokenService] = None
        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        self.page_id: Optional[int] = None
        self.collection_id: Optional[int] = None
        # Title-index row for `title` (has page_id, collection_id, title, rev_id), or None.
        self.title_info: Optional[dict] = None
        # Collection row for `base_title`; only looked up by prepare_update_index().
        self.collection_info: Optional[TitleCollectionModel] = None
        # Page metadata from the MediaWiki API, cached by load_page_info().
        self.page_info: Optional[dict] = None
        # Sentences of the parsed page and the docs among them still lacking
        # embeddings; both are filled by prepare_update_index().
        self.sentences: Optional[list] = None
        self.unindexed_docs: Optional[list] = None

    async def __aenter__(self):
        self.tiktoken = await TikTokenService.create()
        await self.title_index.__aenter__()
        await self.title_collection.__aenter__()

        self.title_info = await self.title_index.find_by_title(self.title)
        if self.title_info is not None:
            # The title is already in the title index: open its page index now.
            self.page_id = self.title_info["page_id"]
            self.collection_id = self.title_info["collection_id"]
            self.page_index = PageIndexHelper(self.dbs, self.collection_id, self.page_id)
            await self.page_index.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Close every helper even if closing an earlier one raises.
        try:
            await self.title_index.__aexit__(exc_type, exc, tb)
        finally:
            try:
                await self.title_collection.__aexit__(exc_type, exc, tb)
            finally:
                if self.page_index is not None:
                    await self.page_index.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self, check_table=True):
        """Return True if a page index helper is open (and, by default, its table exists)."""
        if self.page_index is None:
            return False
        if not check_table:
            return True
        return bool(await self.page_index.table_exists())

    async def load_page_info(self, reload=False):
        """Fetch page metadata from the MediaWiki API unless cached (or ``reload`` is set)."""
        if self.page_info is None or reload:
            self.page_info = await self.mwapi.get_page_info(self.title)

    async def should_update_page_index(self):
        """Return False only when the indexed title and revision match the live page."""
        await self.load_page_info()
        up_to_date = (
            self.title_info is not None
            and await self.page_index_exists()
            and self.title_info["title"] == self.page_info["title"]
            and self.title_info["rev_id"] == self.page_info["lastrevid"]
        )
        return not up_to_date

    async def prepare_update_index(self):
        """Set up the page index and queue sentences that still need embeddings.

        Returns:
            True if an update is pending (call ``update_page_index`` next);
            False if the stored revision is already current.

        Raises:
            Exception: if a new title collection could not be created.
        """
        if not await self.should_update_page_index():
            return False

        self.page_id = self.page_info["pageid"]

        # Find or create the collection shared by all titles with this base title.
        self.collection_info = await self.title_collection.find_by_title(self.base_title)
        if self.collection_info is None:
            self.collection_id = await self.title_collection.add(self.base_title)
            if self.collection_id is None:
                raise Exception("Failed to create title collection")
        else:
            self.collection_id = self.collection_info.id

        new_index = PageIndexHelper(self.dbs, self.collection_id, self.page_id)
        if self.page_index is not None:
            # Close the helper opened in __aenter__; replacing it without this leaks it.
            await self.page_index.__aexit__(None, None, None)
            self.page_index = None
        await new_index.__aenter__()
        self.page_index = new_index
        await self.page_index.init_table()

        page_content = await self.mwapi.parse_page(self.title)
        self.sentences = getWikiSentences(page_content)
        self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False)
        return True

    async def get_unindexed_tokens(self):
        """Return the tiktoken count of the text of every queued doc (0 if none queued)."""
        if self.unindexed_docs is None:
            return 0
        token_count = 0
        for doc in self.unindexed_docs:
            if "text" in doc:
                token_count += await self.tiktoken.get_tokens(doc["text"])
        return token_count

    async def update_page_index(self, on_progress=None, chunk_limit=500):
        """Embed the queued docs, drop outdated ones and refresh the title index.

        Args:
            on_progress: optional async callable ``(processed, total)`` awaited as
                batches complete.
            chunk_limit: batching budget; a batch is flushed once the summed
                ``len(doc)`` of its docs would exceed this value.

        Returns:
            False if ``prepare_update_index`` queued nothing, otherwise the total
            token usage reported by the embeddings API.
        """
        if self.unindexed_docs is None:
            return False

        total_token_usage = await self._embed_unindexed_docs(on_progress, chunk_limit)
        # Runs even when nothing new was embedded: sentences may have been removed.
        await self.page_index.remove_outdated_doc(self.sentences)
        total_token_usage += await self._update_title_index()
        return total_token_usage

    async def _embed_text(self, text):
        """Return ``(embedding, token_usage)`` for a single piece of text."""
        docs, token_usage = await self.openai_api.get_embeddings([{"text": text}])
        return docs[0]["embedding"], token_usage

    async def _embed_and_index(self, docs):
        """Embed one batch of docs and store it in the page index; return token usage."""
        docs, token_usage = await self.openai_api.get_embeddings(docs)
        await self.page_index.index_doc(docs)
        return token_usage

    async def _embed_unindexed_docs(self, on_progress, chunk_limit):
        """Embed ``self.unindexed_docs`` in batches bounded by ``chunk_limit``; return token usage."""
        total = len(self.unindexed_docs)
        if total == 0:
            return 0
        if on_progress is not None:
            await on_progress(0, total)

        token_usage = 0
        processed = 0
        batch = []
        batch_len = 0
        for doc in self.unindexed_docs:
            # NOTE(review): docs are dicts with a "text" key, so len(doc) counts keys,
            # not characters; if a text budget was intended this should be
            # len(doc["text"]) — confirm before changing, it alters batch sizes.
            doc_len = len(doc)
            # Flush only a non-empty batch so an oversized first doc never sends [].
            if batch and batch_len + doc_len > chunk_limit:
                token_usage += await self._embed_and_index(batch)
                processed += len(batch)
                if on_progress is not None:
                    await on_progress(processed, total)
                batch = []
                batch_len = 0
            batch.append(doc)
            batch_len += doc_len

        if batch:
            token_usage += await self._embed_and_index(batch)
            if on_progress is not None:
                await on_progress(total, total)
        return token_usage

    async def _update_title_index(self):
        """Create or refresh this page's title-index row; return embedding token usage."""
        token_usage = 0
        canonical_title = self.page_info["title"]
        lastrevid = self.page_info["lastrevid"]
        if self.title_info is None:
            # Embed the same title that is stored, so the row stays self-consistent.
            embedding, token_usage = await self._embed_text(canonical_title)
            await self.title_index.add(canonical_title, self.page_id, lastrevid,
                                       self.collection_id, embedding)
        elif self.title != canonical_title:
            # The API reports a different title than requested: re-embed it.
            self.title = canonical_title
            embedding, token_usage = await self._embed_text(self.title)
            await self.title_index.update_title_data(self.page_id, self.title, lastrevid,
                                                     self.collection_id, embedding)
        else:
            await self.title_index.update_rev_id(self.title, lastrevid)

        # NOTE(review): collection_info was looked up by base_title, so the
        # base_title comparison looks always true — confirm whether
        # `self.title == self.base_title` was intended.
        if (self.collection_info is None or
                (self.base_title == self.collection_info.title
                 and self.page_id != self.collection_info.page_id)):
            await self.title_collection.set_page_id(self.base_title, self.page_id)
        return token_usage

    async def search(self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6):
        """Embed ``query`` and return ``(hits, token_usage)``.

        Only hits with ``distance < distance_limit`` are kept, each converted to a
        plain dict.

        Raises:
            Exception: if no page index is open.
        """
        if self.page_index is None:
            raise Exception("Page index is not initialized")
        query_embedding, token_usage = await self._embed_text(query)
        if query_embedding is None:
            return [], token_usage
        res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit)
        if not res:
            return res, token_usage
        return [dict(hit) for hit in res if hit["distance"] < distance_limit], token_usage