From 6e0a21d84a78c4e4097ed0c88b562d1a3f591594 Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Sun, 25 Jun 2023 15:20:12 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=88=B7=E6=96=B0=E9=A1=B5?= =?UTF-8?q?=E9=9D=A2=E7=B4=A2=E5=BC=95=E6=97=B6=E7=9A=84=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/model/embedding_search/title_collection.py | 2 +- maintenance/update_title_index.py | 6 +++++- service/embedding_search.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/api/model/embedding_search/title_collection.py b/api/model/embedding_search/title_collection.py index 0523d57..cd928e7 100644 --- a/api/model/embedding_search/title_collection.py +++ b/api/model/embedding_search/title_collection.py @@ -14,7 +14,7 @@ class TitleCollectionModel(BaseModel): page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True, nullable=True) class TitleCollectionHelper(BaseHelper): - async def add(self, title: str, page_id: Optional[int] = None) -> Union[int, bool]: + async def add(self, title: str, page_id: Optional[int] = None) -> TitleCollectionModel | None: stmt = select(TitleCollectionModel.id).where(TitleCollectionModel.title == title) result = await self.session.scalar(stmt) diff --git a/maintenance/update_title_index.py b/maintenance/update_title_index.py index e2fa1f9..7114ba3 100644 --- a/maintenance/update_title_index.py +++ b/maintenance/update_title_index.py @@ -1,4 +1,5 @@ import asyncio +import sys import base as _ import local from service.database import DatabaseService @@ -10,9 +11,12 @@ async def main(): dbs = await DatabaseService.create() mw_api = MediaWikiApi.create() + start_title = sys.argv[1] if len(sys.argv) > 1 else None + continue_key = None while True: - page_res = await mw_api.get_all_pages(continue_key) + page_res = await mw_api.get_all_pages(continue_key, start_title) + start_title = None title_list = page_res["title_list"] diff --git a/service/embedding_search.py b/service/embedding_search.py index a1d5839..319200b 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -160,7 +160,7 @@ class EmbeddingSearchService: # Collection changed, remove old index if self.collection_id != self.title_index.collection_id: async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index: - await self.page_index.init_table() + await old_page_index.init_table() old_page_index.remove_by_page_id(self.page_id) self.title_index.collection_id = self.collection_id