from __future__ import annotations
from typing import Optional, TypedDict

from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
from api.model.embedding_search.title_index import TitleIndexHelper
from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences


class EmbeddingRunningException(Exception):
    """Raised when an indexing run is already in progress for a page."""


class EmbeddingSearchArgs(TypedDict):
    limit: Optional[int]
    in_collection: Optional[bool]
    distance_limit: Optional[float]


class EmbeddingSearchService:
    # Page ids with an indexing run in progress (shared across instances).
    indexing_page_ids: list[int] = []

    def __init__(self, dbs: DatabaseService, title: str):
        self.dbs = dbs
        self.title = title
        # Subpages share the collection of their base page ("Base/Sub" -> "Base").
        self.base_title = title.split("/")[0]

        self.title_index = TitleIndexHelper(dbs)
        self.title_collection = TitleCollectionHelper(dbs)
        self.page_index: Optional[PageIndexHelper] = None
        self.tiktoken: Optional[TikTokenService] = None
        self.mwapi = MediaWikiApi.create()
        self.openai_api = OpenAIApi.create()

        self.page_id: Optional[int] = None
        self.collection_id: Optional[int] = None
        self.title_info: Optional[dict] = None
        self.collection_info: Optional[TitleCollectionModel] = None
        self.page_info: Optional[dict] = None
        self.sentences: list[str] = []
        self.unindexed_docs: Optional[list] = None

    async def __aenter__(self):
        self.tiktoken = await TikTokenService.create()
        await self.title_index.__aenter__()
        await self.title_collection.__aenter__()

        # If the title is already indexed, restore its page/collection context.
        self.title_info = await self.title_index.find_by_title(self.title)
        if self.title_info is not None:
            self.page_id = self.title_info["page_id"]
            self.collection_id = self.title_info["collection_id"]
            self.page_index = PageIndexHelper(self.dbs, self.collection_id, self.page_id)
            await self.page_index.__aenter__()

        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.title_index.__aexit__(exc_type, exc, tb)
        await self.title_collection.__aexit__(exc_type, exc, tb)
        if self.page_index is not None:
            await self.page_index.__aexit__(exc_type, exc, tb)

    async def page_index_exists(self, check_table: bool = True) -> bool:
        if check_table:
            return self.page_index is not None and await self.page_index.table_exists()
        return self.page_index is not None

    async def load_page_info(self, reload: bool = False):
        if self.page_info is None or reload:
            self.page_info = await self.mwapi.get_page_info(self.title)

    async def should_update_page_index(self) -> bool:
        await self.load_page_info()
        if (self.title_info is not None
                and await self.page_index_exists()
                and self.title_info["title"] == self.page_info["title"]
                and self.title_info["rev_id"] == self.page_info["lastrevid"]):
            # Title and revision are unchanged since the last indexing run.
            return False
        return True

    async def prepare_update_index(self) -> bool:
        # Skip when the stored rev_id still matches the live page.
        await self.load_page_info()
        if not await self.should_update_page_index():
            return False

        self.page_id = self.page_info["pageid"]
        if self.page_id in self.indexing_page_ids:
            raise EmbeddingRunningException("Page indexing is already running")

        # Find or create the collection for the base title.
        self.collection_info = await self.title_collection.find_by_title(self.base_title)
        if self.collection_info is None:
            self.collection_id = await self.title_collection.add(self.base_title)
            if self.collection_id is None:
                raise Exception("Failed to create title collection")
        else:
            self.collection_id = self.collection_info.id

        self.page_index = PageIndexHelper(self.dbs, self.collection_id, self.page_id)
        await self.page_index.__aenter__()
        await self.page_index.init_table()

        page_content = await self.mwapi.parse_page(self.title)
        self.sentences = getWikiSentences(page_content)
        self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False)
        return True

    async def get_unindexed_tokens(self) -> int:
        """Count the tokens in the docs that still need to be embedded."""
        if self.unindexed_docs is None:
            return 0
        tokens = 0
        for doc in self.unindexed_docs:
            if "text" in doc:
                tokens += await self.tiktoken.get_tokens(doc["text"])
        return tokens

    async def update_page_index(self, on_progress=None):
        if self.unindexed_docs is None:
            return False

        chunk_limit = 500
        chunk_len = 0
        doc_chunk = []
        total_token_usage = 0
        processed_len = 0

        async def on_embedding_progress(current, length):
            # Progress within the current chunk, offset by finished chunks.
            nonlocal processed_len
            indexed_docs = processed_len + current
            if on_progress is not None:
                await on_progress(indexed_docs, len(self.unindexed_docs))

        async def embedding_doc(chunk: list):
            # Embed one chunk of docs and store it in the page index.
            (embedded, token_usage) = await self.openai_api.get_embeddings(chunk, on_embedding_progress)
            await self.page_index.index_doc(embedded)
            return token_usage

        if len(self.unindexed_docs) > 0:
            if on_progress is not None:
                await on_progress(0, len(self.unindexed_docs))
            for doc in self.unindexed_docs:
                # Batch by accumulated text length; len(doc) on a dict would
                # only count its keys, so measure the text payload instead.
                doc_len = len(doc.get("text", ""))
                chunk_len += doc_len
                if chunk_len > chunk_limit:
                    total_token_usage += await embedding_doc(doc_chunk)
                    processed_len += len(doc_chunk)
                    if on_progress is not None:
                        await on_progress(processed_len, len(self.unindexed_docs))
                    doc_chunk = []
                    chunk_len = doc_len
                doc_chunk.append(doc)
            if len(doc_chunk) > 0:
                # Flush the trailing partial chunk.
                total_token_usage += await embedding_doc(doc_chunk)
                if on_progress is not None:
                    await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))

        await self.page_index.remove_outdated_doc(self.sentences)

        # Update the title index: insert on first indexing, re-embed on rename,
        # otherwise just bump the stored revision id.
        if self.title_info is None:
            doc_chunk = [{"text": self.title}]
            (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
            total_token_usage += token_usage
            embedding = doc_chunk[0]["embedding"]
            await self.title_index.add(self.page_info["title"], self.page_id,
                                       self.page_info["lastrevid"], self.collection_id, embedding)
        elif self.title != self.page_info["title"]:
            self.title = self.page_info["title"]
            doc_chunk = [{"text": self.title}]
            (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
            total_token_usage += token_usage
            embedding = doc_chunk[0]["embedding"]
            await self.title_index.update_title_data(self.page_id, self.title,
                                                     self.page_info["lastrevid"], self.collection_id, embedding)
        else:
            await self.title_index.update_rev_id(self.title, self.page_info["lastrevid"])

        if (self.collection_info is None
                or (self.base_title == self.collection_info.title
                    and self.page_id != self.collection_info.page_id)):
            # Point the collection at this page when it owns the base title.
            await self.title_collection.set_page_id(self.base_title, self.page_id)

        return total_token_usage

    async def search(self, query: str, limit: int = 10,
                     in_collection: bool = False, distance_limit: float = 0.6):
        if self.page_index is None:
            raise Exception("Page index is not initialized")

        # Embed the query, then run a vector search against the page index.
        query_doc = [{"text": query}]
        query_doc, token_usage = await self.openai_api.get_embeddings(query_doc)
        query_embedding = query_doc[0]["embedding"]
        if query_embedding is None:
            return [], token_usage

        res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit)
        if not res:
            return res, token_usage

        # Keep only results closer than the distance cutoff.
        filtered = [dict(one) for one in res if one["distance"] < distance_limit]
        return filtered, token_usage
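

# --- Usage sketch -------------------------------------------------------------
# A minimal sketch of the intended call sequence, not a confirmed entry point.
# How a DatabaseService instance is obtained is outside this module; the page
# title and query below are illustrative assumptions. Note that on_progress is
# awaited by update_page_index, so it must be a coroutine function.
async def _example_usage(dbs: DatabaseService):
    async def on_progress(done: int, total: int):
        print(f"embedding progress: {done}/{total}")

    async with EmbeddingSearchService(dbs, "Example/Subpage") as service:
        # Re-index only if the page title or revision changed.
        if await service.prepare_update_index():
            print("tokens to embed:", await service.get_unindexed_tokens())
            print("token usage:", await service.update_page_index(on_progress))

        # Vector search within this page (pass in_collection=True to widen the
        # search to the whole collection).
        results, token_usage = await service.search("example query", limit=5)
        for row in results:
            print(row["distance"], row)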