from __future__ import annotations from typing import Optional, TypedDict import sqlalchemy from api.model.embedding_search.title_collection import ( TitleCollectionHelper, TitleCollectionModel, ) from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel from api.model.embedding_search.page_index import PageIndexHelper from service.database import DatabaseService from service.mediawiki_api import MediaWikiApi from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService from utils.wiki import getWikiSentences class EmbeddingRunningException(Exception): pass class EmbeddingSearchArgs(TypedDict): limit: Optional[int] in_collection: Optional[bool] distance_limit: Optional[float] class EmbeddingSearchService: indexing_page_ids: list[int] = [] def __init__(self, dbs: DatabaseService, title: str): self.dbs = dbs self.title = title self.base_title = title.split("/")[0] self.title_index_helper = TitleIndexHelper(dbs) self.title_collection_helper = TitleCollectionHelper(dbs) self.page_index: PageIndexHelper = None self.tiktoken: TikTokenService = None self.mwapi = MediaWikiApi.create() self.openai_api = OpenAIApi.create() self.page_id: Optional[int] = None self.collection_id: Optional[int] = None self.title_index: Optional[TitleIndexModel] = None self.collection_info: Optional[TitleCollectionModel] = None self.page_info: dict = None self.unindexed_docs: list = None async def __aenter__(self): self.tiktoken = await TikTokenService.create() await self.title_index_helper.__aenter__() await self.title_collection_helper.__aenter__() self.title_index = await self.title_index_helper.find_by_title(self.title) if self.title_index is None: # Title may changed, get page info from page_id await self.load_page_info() self.title_index = await self.title_index_helper.find_by_page_id( self.page_info["pageid"] ) self.page_id = self.page_info["pageid"] else: self.page_id = self.title_index.page_id self.collection_id = self.title_index.collection_id self.page_index = PageIndexHelper(self.dbs, self.collection_id) await self.page_index.__aenter__() return self async def __aexit__(self, exc_type, exc, tb): await self.title_index_helper.__aexit__(exc_type, exc, tb) await self.title_collection_helper.__aexit__(exc_type, exc, tb) if self.page_index is not None: await self.page_index.__aexit__(exc_type, exc, tb) async def page_index_exists(self, check_table=True): if check_table: return self.page_index and await self.page_index.table_exists() else: return self.page_index is not None async def load_page_info(self, reload=False): if self.page_info is None or reload: self.page_info = await self.mwapi.get_page_info(self.title) async def should_update_page_index(self, remote_update=False): if not remote_update: if self.title_index is None: return True return self.title_index.indexed_rev_id != self.title_index.latest_rev_id else: await self.load_page_info() if ( self.title_index is not None and await self.page_index_exists() and self.title_index.indexed_rev_id == self.page_info["lastrevid"] ): # Not changed return False return True async def update_title_index(self, remote_update=False): if not await self.should_update_page_index(remote_update): return False await self.load_page_info() self.page_id = self.page_info["pageid"] if self.page_id in self.indexing_page_ids: raise EmbeddingRunningException("Page index is running now") self.title: str = self.page_info["title"] self.base_title = self.title.split("/")[0] # Find collection by base title self.collection_info = await self.title_collection_helper.find_by_title( self.base_title ) if self.collection_info is None: # Create collection self.collection_info = await self.title_collection_helper.add( self.base_title ) if self.collection_info is None: raise Exception("Failed to create title collection") self.collection_id = self.collection_info.id if self.title_index == None: # Create title record self.title_index = TitleIndexModel( title=self.page_info["title"], page_id=self.page_id, indexed_rev_id=None, latest_rev_id=self.page_info["lastrevid"], collection_id=self.collection_id, embedding=None, ) self.title_index = await self.title_index_helper.add(self.title_index) if self.title_index is None: raise Exception("Failed to create title index") else: self.title_index.latest_rev_id = self.page_info["lastrevid"] # Title changed, remove embedding # Title sha1 will be updated by model helper if self.title_index.title != self.page_info["title"]: self.title_index.title = self.page_info["title"] self.title_index.embedding = None # Collection changed, remove old index if self.collection_id != self.title_index.collection_id: async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index: await old_page_index.init_table() old_page_index.remove_by_page_id(self.page_id) self.title_index.collection_id = self.collection_id await self.title_index_helper.update(self.title_index) # Update collection main page id if ( self.title == self.collection_info.title and self.page_id != self.collection_info.page_id ): await self.title_collection_helper.set_main_page_id( self.base_title, self.page_id ) if self.page_index: await self.page_index.__aexit__(None, None, None) self.page_index = PageIndexHelper(self.dbs, self.collection_id) await self.page_index.__aenter__() await self.page_index.init_table() async def prepare_update_index(self): await self.update_title_index() page_content = await self.mwapi.parse_page(self.title) self.sentences = getWikiSentences(page_content) self.unindexed_docs = await self.page_index.get_unindexed_doc( self.sentences, self.page_id, with_temporary=False ) return True async def get_unindexed_tokens(self): if self.unindexed_docs is None: return 0 else: tokens = 0 for doc in self.unindexed_docs: if "text" in doc: tokens += await self.tiktoken.get_tokens(doc["text"]) return tokens async def update_page_index(self, on_progress=None): if self.unindexed_docs is None: return False chunk_limit = 500 chunk_len = 0 doc_chunk = [] total_token_usage = 0 processed_len = 0 async def on_embedding_progress(current, length): nonlocal processed_len indexed_docs = processed_len + current if on_progress is not None: await on_progress(indexed_docs, len(self.unindexed_docs)) async def embedding_doc(doc_chunk): (doc_chunk, token_usage) = await self.openai_api.get_embeddings( doc_chunk, on_embedding_progress ) await self.page_index.index_doc(doc_chunk, self.page_id) return token_usage if len(self.unindexed_docs) > 0: if on_progress is not None: await on_progress(0, len(self.unindexed_docs)) for doc in self.unindexed_docs: chunk_len += len(doc) if chunk_len > chunk_limit: total_token_usage += await embedding_doc(doc_chunk) processed_len += len(doc_chunk) if on_progress is not None: await on_progress(processed_len, len(self.unindexed_docs)) doc_chunk = [] chunk_len = len(doc) doc_chunk.append(doc) if len(doc_chunk) > 0: total_token_usage += await embedding_doc(doc_chunk) if on_progress is not None: await on_progress(len(self.unindexed_docs), len(self.unindexed_docs)) await self.page_index.remove_outdated_doc(self.sentences, self.page_id) # Update database # This task may take a long time, refresh model to retrieve latest data self.title_index = await self.title_index_helper.refresh(self.title_index) self.title_index.indexed_rev_id = self.page_info["lastrevid"] # Update title embedding if await self.title_index.awaitable_attrs.embedding is None: doc_chunk = [{"text": self.title}] (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) total_token_usage += token_usage embedding = doc_chunk[0]["embedding"] self.title_index.embedding = embedding await self.title_index_helper.update(self.title_index) return total_token_usage async def search( self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6, ): if self.page_index is None: raise Exception("Page index is not initialized") query_doc = [{"text": query}] query_doc, token_usage = await self.openai_api.get_embeddings(query_doc) query_embedding = query_doc[0]["embedding"] if query_embedding is None: return [], token_usage res = await self.page_index.search_text_embedding( query_embedding, in_collection, limit, self.page_id ) if res: filtered = [] for one in res: if one["distance"] < distance_limit: filtered.append(dict(one)) return filtered, token_usage else: return res, token_usage