From 4aa627a23ee28c4a5a45d8f9699cdbb3de817aed Mon Sep 17 00:00:00 2001 From: Lex Lim Date: Sun, 25 Jun 2023 09:40:34 +0000 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=9C=A8Collection=E4=B8=AD?= =?UTF-8?q?=E6=8F=90=E9=97=AE=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controller/EmbeddingSearch.py | 172 ++++++++++----- api/controller/task/ChatCompleteTask.py | 3 +- api/model/base.py | 3 +- api/model/embedding_search/page_index.py | 196 +++++++++++------ .../embedding_search/title_collection.py | 10 +- api/model/embedding_search/title_index.py | 29 ++- api/route.py | 1 + maintenance/base.py | 4 + maintenance/update_title_index.py | 35 +++ requirements.txt | 2 +- service/chat_complete.py | 6 +- service/embedding_search.py | 201 +++++++++++------- service/mediawiki_api.py | 36 ++++ 13 files changed, 491 insertions(+), 207 deletions(-) create mode 100644 maintenance/base.py create mode 100644 maintenance/update_title_index.py diff --git a/api/controller/EmbeddingSearch.py b/api/controller/EmbeddingSearch.py index ba55def..0984ecd 100644 --- a/api/controller/EmbeddingSearch.py +++ b/api/controller/EmbeddingSearch.py @@ -1,6 +1,8 @@ import sys import traceback from aiohttp import web +from api.model.embedding_search.title_collection import TitleCollectionHelper +from api.model.embedding_search.title_index import TitleIndexHelper from service.database import DatabaseService from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException @@ -14,9 +16,15 @@ class EmbeddingSearch: "title": { "required": True, }, + "collection": { + "required": False, + "type": bool, + "default": False + } }) - page_title = params.get('title') + page_title = params.get("title") + is_collection = params.get("collection") mwapi = MediaWikiApi.create() db = await DatabaseService.create(request.app) @@ -27,47 +35,60 @@ class EmbeddingSearch: try: transatcion_id = None - async with EmbeddingSearchService(db, page_title) as embedding_search: - if await embedding_search.should_update_page_index(): - if request.get("caller") == "user": - user_id = request.get("user") - usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") - transatcion_id = usage_res.get("transaction_id") + title_list = [page_title] + if is_collection: + # Get collection titles + async with TitleCollectionHelper(db) as title_collection, TitleIndexHelper(db) as title_index_helper: + title_collection = await title_collection.find_by_title(page_title) + if title_collection is not None: + need_update_pages = await title_index_helper.get_need_update_index_list(title_collection.id) + title_list = [] + for page_info in need_update_pages: + title_list.append(page_info.title) - await embedding_search.prepare_update_index() + page_count = len(title_list) + page_current = 0 + index_updated = False + for one_title in title_list: + page_current += 1 + async with EmbeddingSearchService(db, one_title) as embedding_search: + if await embedding_search.should_update_page_index(): + if request.get("caller") == "user": + user_id = request.get("user") + usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage") + transatcion_id = usage_res.get("transaction_id") - async def on_progress(current, total): + await embedding_search.prepare_update_index() - await ws.send_json({ - 'event': 'progress', - 'current': current, - 'total': total - }) + async def on_progress(current, total): + await ws.send_json({ + "event": "progress", + "current": current, + "total": total, + "current_page": page_current, + "total_page": page_count, + }) - token_usage = await embedding_search.update_page_index(on_progress) - await ws.send_json({ - 'event': 'done', - 'status': 1, - 'index_updated': True - }) + token_usage = await embedding_search.update_page_index(on_progress) + index_updated = True - if transatcion_id: - await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) - else: - await ws.send_json({ - 'event': 'done', - 'status': 1, - 'index_updated': False - }) + if transatcion_id: + await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) + + await ws.send_json({ + "event": "done", + "status": 1, + "index_updated": index_updated, + }) except MediaWikiPageNotFoundException: error_msg = "Page \"%s\" not found." % page_title await ws.send_json({ - 'event': 'error', - 'status': -2, - 'message': error_msg, - 'error': { - 'code': 'page_not_found', - 'title': page_title, + "event": "error", + "status": -2, + "message": error_msg, + "error": { + "code": "page_not_found", + "title": page_title, }, }) if transatcion_id: @@ -77,12 +98,12 @@ class EmbeddingSearch: print(error_msg, file=sys.stderr) traceback.print_exc() await ws.send_json({ - 'event': 'error', - 'status': -3, - 'message': error_msg, - 'error': { - 'code': e.code, - 'info': e.info, + "event": "error", + "status": -3, + "message": error_msg, + "error": { + "code": e.code, + "info": e.info, }, }) if transatcion_id: @@ -90,11 +111,11 @@ class EmbeddingSearch: except EmbeddingRunningException: error_msg = "Page index is running now" await ws.send_json({ - 'event': 'error', - 'status': -4, - 'message': error_msg, - 'error': { - 'code': 'page_index_running', + "event": "error", + "status": -4, + "message": error_msg, + "error": { + "code": "page_index_running", }, }) if transatcion_id: @@ -104,11 +125,11 @@ class EmbeddingSearch: print(error_msg, file=sys.stderr) traceback.print_exc() await ws.send_json({ - 'event': 'error', - 'status': -1, - 'message': error_msg, - 'error': { - 'code': 'internal_server_error', + "event": "error", + "status": -1, + "message": error_msg, + "error": { + "code": "internal_server_error", } }) if transatcion_id: @@ -133,9 +154,15 @@ class EmbeddingSearch: if transatcion_id: result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) - return await utils.web.api_response(1, {"data_indexed": True}) + return await utils.web.api_response(1, { + "data_indexed": True, + "title": embedding_search.title + }) else: - return await utils.web.api_response(1, {"data_indexed": False}) + return await utils.web.api_response(1, { + "data_indexed": False, + "title": embedding_search.title + }) except MediaWikiPageNotFoundException: error_msg = "Page \"%s\" not found." % page_title if transatcion_id: @@ -253,4 +280,43 @@ class EmbeddingSearch: "code": "internal-server-error", "message": error_msg }, request=request, http_status=500) - return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request) \ No newline at end of file + return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request) + + @staticmethod + @utils.web.token_auth + async def sys_update_title_info(request: web.Request): + params = await utils.web.get_param(request, { + "title": { + "required": True, + }, + }) + + if request.get("caller") == "user": + return await utils.web.api_response(-1, error={ + "code": "permission-denied", + "message": "This api is only for system caller." + }, request=request, http_status=403) + + page_title = params.get("title") + + db = await DatabaseService.create(request.app) + + async with EmbeddingSearchService(db, page_title) as embedding_search: + try: + await embedding_search.update_title_index(True) + except MediaWikiPageNotFoundException: + error_msg = "Page \"%s\" not found." % page_title + + return await utils.web.api_response(-2, error={ + "code": "page-not-found", + "title": page_title, + "message": error_msg + }, request=request, http_status=404) + except Exception as err: + err_msg = str(err) + return await utils.web.api_response(-1, error={ + "code": "internal-server-error", + "message": err_msg + }, request=request, http_status=500) + + return await utils.web.api_response(1, request=request) \ No newline at end of file diff --git a/api/controller/task/ChatCompleteTask.py b/api/controller/task/ChatCompleteTask.py index b1ab227..f8c3a8f 100644 --- a/api/controller/task/ChatCompleteTask.py +++ b/api/controller/task/ChatCompleteTask.py @@ -129,7 +129,8 @@ class ChatCompleteTask: async def end(self): await self.chat_complete_service.__aexit__(None, None, None) - del chat_complete_tasks[self.task_id] + if self.task_id in chat_complete_tasks: + del chat_complete_tasks[self.task_id] self.is_finished = True self.finished_time = time.time() diff --git a/api/model/base.py b/api/model/base.py index f4fcd96..cda20d6 100644 --- a/api/model/base.py +++ b/api/model/base.py @@ -2,10 +2,11 @@ from __future__ import annotations from typing import TypeVar import sqlalchemy from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.ext.asyncio import AsyncAttrs from service.database import DatabaseService -class BaseModel(DeclarativeBase): +class BaseModel(AsyncAttrs, DeclarativeBase): pass class BaseHelper: diff --git a/api/model/embedding_search/page_index.py b/api/model/embedding_search/page_index.py index 0081050..541f375 100644 --- a/api/model/embedding_search/page_index.py +++ b/api/model/embedding_search/page_index.py @@ -21,7 +21,9 @@ page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {} class AbstractPageIndexModel(BaseModel): __abstract__ = True - id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) + id: Mapped[int] = mapped_column( + sqlalchemy.Integer, primary_key=True, autoincrement=True + ) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True) embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)) @@ -36,34 +38,37 @@ def create_page_index_model(collection_id: int): if collection_id in page_index_model_list: return page_index_model_list[collection_id] else: + class _PageIndexModel(AbstractPageIndexModel): __tablename__ = "embedding_search_page_index_%s" % str(collection_id) - - embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx", AbstractPageIndexModel.embedding, - postgresql_using='ivfflat', - postgresql_ops={'embedding': 'vector_cosine_ops'}) + + embedding_index = sqlalchemy.Index( + __tablename__ + "_embedding_idx", + AbstractPageIndexModel.embedding, + postgresql_using="ivfflat", + postgresql_ops={"embedding": "vector_cosine_ops"}, + ) page_index_model_list[collection_id] = _PageIndexModel return _PageIndexModel + class PageIndexHelper: columns = [ "id", - "page_id" - "sha1", + "page_id" "sha1", "text", "text_len", "markdown", "markdown_len", "embedding", - "temp_doc_session_id" + "temp_doc_session_id", ] - def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]): + def __init__(self, dbs: DatabaseService, collection_id: int): self.dbs = dbs self.collection_id = collection_id - self.page_id = page_id if page_id is not None else -1 self.table_name = "embedding_search_page_index_%s" % str(collection_id) self.initialized = False self.table_initialized = False @@ -71,10 +76,11 @@ class PageIndexHelper: """ Initialize table """ + async def __aenter__(self): if self.initialized: return - + self.dbpool = self.dbs.pool.acquire() self.dbi = await self.dbpool.__aenter__() @@ -95,12 +101,16 @@ class PageIndexHelper: await self.session.__aexit__(exc_type, exc, tb) async def table_exists(self): - exists = await self.dbi.fetchval("""SELECT EXISTS ( + exists = await self.dbi.fetchval( + """SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1 - );""", self.table_name, column=0) + );""", + self.table_name, + column=0, + ) return bool(exists) @@ -113,22 +123,6 @@ class PageIndexHelper: async with self.dbs.engine.begin() as conn: await conn.run_sync(self.orm.__table__.create) - # await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ ( - # id SERIAL PRIMARY KEY, - # page_id INTEGER NOT NULL, - # sha1 VARCHAR(40) NOT NULL, - # text TEXT NOT NULL, - # text_len INTEGER NOT NULL, - # embedding VECTOR(%d) NOT NULL, - # markdown TEXT NULL, - # markdown_len INTEGER NULL, - # temp_doc_session_id INTEGER NULL - # ); - # CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id); - # CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1); - # CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id); - # """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name)) - self.table_initialized = True async def create_embedding_index(self): @@ -142,7 +136,9 @@ class PageIndexHelper: sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest() item["sha1"] = sha1 - async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False): + async def get_indexed_sha1( + self, page_id: int, with_temporary: bool = True, in_collection: bool = False + ): indexed_sha1_list = [] stmt = select(self.orm).column(self.orm.sha1) @@ -151,7 +147,7 @@ class PageIndexHelper: stmt = stmt.where(self.orm.temp_doc_session_id == None) if not in_collection: - stmt = stmt.where(self.orm.page_id == self.page_id) + stmt = stmt.where(self.orm.page_id == page_id) ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt) @@ -160,8 +156,10 @@ class PageIndexHelper: return indexed_sha1_list - async def get_unindexed_doc(self, doc: list, with_temporary: bool = True): - indexed_sha1_list = await self.get_indexed_sha1(with_temporary) + async def get_unindexed_doc( + self, doc: list, page_id: int, with_temporary: bool = True + ): + indexed_sha1_list = await self.get_indexed_sha1(page_id, with_temporary) self.sha1_doc(doc) should_index = [] @@ -171,10 +169,10 @@ class PageIndexHelper: return should_index - async def remove_outdated_doc(self, doc: list): - await self.clear_temp() + async def remove_outdated_doc(self, doc: list, page_id: int): + await self.clear_temp(page_id=page_id) - indexed_sha1_list = await self.get_indexed_sha1(False) + indexed_sha1_list = await self.get_indexed_sha1(page_id, False) self.sha1_doc(doc) doc_sha1_list = [item["sha1"] for item in doc] @@ -185,17 +183,24 @@ class PageIndexHelper: should_remove.append(sha1) if len(should_remove) > 0: - await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name), - self.page_id, should_remove) - - async def index_doc(self, doc: list): + await self.dbi.execute( + "DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" + % (self.table_name), + page_id, + should_remove, + ) + + async def index_doc(self, doc: list, page_id: int): need_create_index = False indexed_persist_sha1_list = [] indexed_temp_sha1_list = [] - ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name), - self.page_id) + ret = await self.dbi.fetch( + "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" + % (self.table_name), + page_id, + ) for row in ret: if row[1]: indexed_temp_sha1_list.append(row[0]) @@ -226,28 +231,48 @@ class PageIndexHelper: should_remove.append(sha1) if len(should_index) > 0: - await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name), - [(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index]) + await self.dbi.executemany( + """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" + % (self.table_name), + [ + ( + item["sha1"], + page_id, + item["text"], + len(item["text"]), + item["markdown"], + len(item["markdown"]), + item["embedding"], + ) + for item in should_index + ], + ) if len(should_persist) > 0: - await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name), - [(self.page_id, sha1) for sha1 in should_persist]) - + await self.dbi.executemany( + "UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" + % (self.table_name), + [(page_id, sha1) for sha1 in should_persist], + ) + if need_create_index: await self.create_embedding_index() """ Add temporary document to the index """ + async def index_temp_doc(self, doc: list, temp_doc_session_id: int): indexed_sha1_list = [] indexed_temp_sha1_list = [] doc_sha1_list = [] - sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % ( - self.table_name) - ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id) + sql = ( + "SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)" + % (self.table_name) + ) + ret = await self.dbi.fetch(sql, temp_doc_session_id) for row in ret: indexed_sha1_list.append(row[0]) if row[1]: @@ -269,41 +294,79 @@ class PageIndexHelper: should_remove.append(sha1) if len(should_index) > 0: - await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name), - [(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index]) + await self.dbi.executemany( + """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" + % (self.table_name), + [ + ( + item["sha1"], + item["page_id"], + item["text"], + len(item["text"]), + item["markdown"], + len(item["markdown"]), + item["embedding"], + temp_doc_session_id, + ) + for item in should_index + ], + ) if len(should_remove) > 0: - await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name), - self.page_id, temp_doc_session_id, should_remove) + await self.dbi.execute( + "DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)" + % (self.table_name), + temp_doc_session_id, + should_remove, + ) """ Search for text by consine similary """ - async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10): + + async def search_text_embedding( + self, + embedding: np.ndarray, + in_collection: bool = False, + limit: int = 10, + page_id: Optional[int] = None, + ): if in_collection: - return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance + return await self.dbi.fetch( + """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance FROM %s ORDER BY distance ASC - LIMIT %d""" % (self.table_name, limit), embedding) + LIMIT %d""" + % (self.table_name, limit), + embedding, + ) else: - return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance + return await self.dbi.fetch( + """SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance FROM %s WHERE page_id = $2 ORDER BY distance ASC - LIMIT %d""" % (self.table_name, limit), embedding, self.page_id) + LIMIT %d""" + % (self.table_name, limit), + embedding, + page_id, + ) """ Clear temporary index """ - async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None): + + async def clear_temp( + self, in_collection: bool = False, temp_doc_session_id: int = None, page_id: Optional[int] = None + ): sql = "DELETE FROM %s" % (self.table_name) where = [] params = [] - if not in_collection: - params.append(self.page_id) + if not in_collection and page_id: + params.append(page_id) where.append("page_id = $%d" % len(params)) if temp_doc_session_id: @@ -316,3 +379,8 @@ class PageIndexHelper: sql += " WHERE " + (" AND ".join(where)) await self.dbi.execute(sql, *params) + + async def remove_by_page_id(self, page_id: int): + stmt = delete(self.orm).where(self.orm.page_id == page_id) + await self.session.execute(stmt) + await self.session.commit() diff --git a/api/model/embedding_search/title_collection.py b/api/model/embedding_search/title_collection.py index fad054a..0523d57 100644 --- a/api/model/embedding_search/title_collection.py +++ b/api/model/embedding_search/title_collection.py @@ -23,11 +23,11 @@ class TitleCollectionHelper(BaseHelper): self.session.add(obj) await self.session.commit() await self.session.refresh(obj) - return obj.id + return obj - return False + return None - async def set_page_id(self, title: str, page_id: Optional[str] = None): + async def set_main_page_id(self, title: str, page_id: Optional[str] = None): stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id) await self.session.execute(stmt) await self.session.commit() @@ -37,6 +37,10 @@ class TitleCollectionHelper(BaseHelper): await self.session.execute(stmt) await self.session.commit() + async def find_by_id(self, id: int): + stmt = select(TitleCollectionModel).where(TitleCollectionModel.id == id) + return await self.session.scalar(stmt) + async def find_by_title(self, title: str): stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title) return await self.session.scalar(stmt) diff --git a/api/model/embedding_search/title_index.py b/api/model/embedding_search/title_index.py index cce023f..b79aa0d 100644 --- a/api/model/embedding_search/title_index.py +++ b/api/model/embedding_search/title_index.py @@ -5,7 +5,7 @@ import numpy as np from pgvector.sqlalchemy import Vector from pgvector.asyncpg import register_vector import sqlalchemy -from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred +from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer from sqlalchemy.ext.asyncio import AsyncEngine import config @@ -20,9 +20,9 @@ class TitleIndexModel(BaseModel): title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) - rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) + indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True) latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) - embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))) + embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True)) embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, postgresql_using='ivfflat', @@ -37,7 +37,7 @@ class TitleIndexHelper(BaseHelper): "title", "page_id", "collection_id", - "rev_id", + "indexed_rev_id", "latest_rev_id" "embedding", ] @@ -52,11 +52,11 @@ class TitleIndexHelper(BaseHelper): await register_vector(self.dbi) - await super().__aenter__() + return await super().__aenter__() async def __aexit__(self, exc_type, exc, tb): await self.dbpool.__aexit__(exc_type, exc, tb) - await super().__aexit__(exc_type, exc, tb) + return await super().__aexit__(exc_type, exc, tb) def get_columns(self, exclude=[]): if len(exclude) == 0: @@ -92,6 +92,10 @@ class TitleIndexHelper(BaseHelper): await self.session.commit() await self.session.refresh(obj) return obj + + async def refresh(self, obj: TitleIndexModel): + await self.session.refresh(obj) + return obj """ Search for titles by consine similary @@ -120,6 +124,15 @@ class TitleIndexHelper(BaseHelper): stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id) return await self.session.scalar(stmt) - async def find_list_by_collection_id(self, collection_id: int): + async def find_list_by_collection_id(self, collection_id: int) -> sqlalchemy.ScalarResult[TitleIndexModel]: stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id) - return await self.session.scalars(stmt) \ No newline at end of file + return await self.session.scalars(stmt) + + async def get_need_update_index_list(self, collection_id: int) -> list[TitleIndexModel]: + page_list = await self.find_list_by_collection_id(collection_id) + result: list[TitleIndexModel] = [] + for page_info in page_list: + if page_info.indexed_rev_id != page_info.latest_rev_id: + result.append(page_info) + + return result \ No newline at end of file diff --git a/api/route.py b/api/route.py index e6a2eda..cb45799 100644 --- a/api/route.py +++ b/api/route.py @@ -28,6 +28,7 @@ def init(app: web.Application): web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page), web.route('*', '/embedding_search/search', EmbeddingSearch.search), + web.route('POST', '/sys/embedding_search/title/update', EmbeddingSearch.sys_update_title_info), web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk), diff --git a/maintenance/base.py b/maintenance/base.py new file mode 100644 index 0000000..bb81f7d --- /dev/null +++ b/maintenance/base.py @@ -0,0 +1,4 @@ +import sys +import pathlib + +sys.path.append(str(pathlib.Path(__file__).parent.parent)) \ No newline at end of file diff --git a/maintenance/update_title_index.py b/maintenance/update_title_index.py new file mode 100644 index 0000000..e2fa1f9 --- /dev/null +++ b/maintenance/update_title_index.py @@ -0,0 +1,35 @@ +import asyncio +import base as _ +import local +from service.database import DatabaseService + +from service.embedding_search import EmbeddingSearchService +from service.mediawiki_api import MediaWikiApi + +async def main(): + dbs = await DatabaseService.create() + mw_api = MediaWikiApi.create() + + continue_key = None + while True: + page_res = await mw_api.get_all_pages(continue_key) + + title_list = page_res["title_list"] + + for page_title in title_list: + print("Indexing %s" % page_title) + async with EmbeddingSearchService(dbs, page_title) as embedding_search: + await embedding_search.update_title_index(True) + + if not page_res["continue_key"]: + break + + continue_key = page_res["continue_key"] + + print("Done") + + await local.noawait.end() + await asyncio.sleep(1) + +if __name__ == '__main__': + local.loop.run_until_complete(main()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 523fa00..f5eb69d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ pgvector==0.1.6 websockets==11.0 PyJWT==2.6.0 asyncpg-stubs==0.27.0 -sqlalchemy==2.0.9 +sqlalchemy==2.0.17 aiohttp-sse-client2==0.3.0 OpenCC==1.1.6 event-emitter-asyncio==1.0.4 diff --git a/service/chat_complete.py b/service/chat_complete.py index 3dc12e0..48aca2e 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -196,12 +196,12 @@ class ChatCompleteService: self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk) else: # 创建新对话 - title_info = self.embedding_search.title_info + title_info = self.embedding_search.title_index self.conversation_info = ConversationModel( user_id=self.user_id, module="chatcomplete", - page_id=title_info["page_id"], - rev_id=title_info["rev_id"], + page_id=title_info.page_id, + rev_id=title_info.latest_rev_id, ) self.conversation_info = await self.conversation_helper.add( self.conversation_info, diff --git a/service/embedding_search.py b/service/embedding_search.py index 5507ce0..a1d5839 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -1,6 +1,11 @@ from __future__ import annotations from typing import Optional, TypedDict -from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel + +import sqlalchemy +from api.model.embedding_search.title_collection import ( + TitleCollectionHelper, + TitleCollectionModel, +) from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel from api.model.embedding_search.page_index import PageIndexHelper from service.database import DatabaseService @@ -9,14 +14,17 @@ from service.openai_api import OpenAIApi from service.tiktoken import TikTokenService from utils.wiki import getWikiSentences + class EmbeddingRunningException(Exception): pass + class EmbeddingSearchArgs(TypedDict): limit: Optional[int] in_collection: Optional[bool] distance_limit: Optional[float] + class EmbeddingSearchService: indexing_page_ids: list[int] = [] @@ -38,7 +46,7 @@ class EmbeddingSearchService: self.page_id: Optional[int] = None self.collection_id: Optional[int] = None - self.title_info: Optional[TitleIndexModel] = None + self.title_index: Optional[TitleIndexModel] = None self.collection_info: Optional[TitleCollectionModel] = None self.page_info: dict = None @@ -50,17 +58,22 @@ class EmbeddingSearchService: await self.title_index_helper.__aenter__() await self.title_collection_helper.__aenter__() - self.title_info = await self.title_index_helper.find_by_title(self.title) - if self.title_info is not None: - self.page_id = self.title_info.page_id - self.collection_id = self.title_info.collection_id - - self.page_index = PageIndexHelper( - self.dbs, self.collection_id, self.page_id) + self.title_index = await self.title_index_helper.find_by_title(self.title) + if self.title_index is None: + # Title may changed, get page info from page_id + await self.load_page_info() + self.title_index = await self.title_index_helper.find_by_page_id( + self.page_info["pageid"] + ) + self.page_id = self.page_info["pageid"] + else: + self.page_id = self.title_index.page_id + self.collection_id = self.title_index.collection_id + self.page_index = PageIndexHelper(self.dbs, self.collection_id) await self.page_index.__aenter__() return self - + async def __aexit__(self, exc_type, exc, tb): await self.title_index_helper.__aexit__(exc_type, exc, tb) await self.title_collection_helper.__aexit__(exc_type, exc, tb) @@ -68,7 +81,7 @@ class EmbeddingSearchService: if self.page_index is not None: await self.page_index.__aexit__(exc_type, exc, tb) - async def page_index_exists(self, check_table = True): + async def page_index_exists(self, check_table=True): if check_table: return self.page_index and await self.page_index.table_exists() else: @@ -78,47 +91,106 @@ class EmbeddingSearchService: if self.page_info is None or reload: self.page_info = await self.mwapi.get_page_info(self.title) - async def should_update_page_index(self): - await self.load_page_info() + async def should_update_page_index(self, remote_update=False): + if not remote_update: + if self.title_index is None: + return True + return self.title_index.indexed_rev_id != self.title_index.latest_rev_id + else: + await self.load_page_info() - if (self.title_info is not None and await self.page_index_exists() and - self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]): - # Not changed - return False - - return True + if ( + self.title_index is not None + and await self.page_index_exists() + and self.title_index.indexed_rev_id == self.page_info["lastrevid"] + ): + # Not changed + return False - async def prepare_update_index(self): - # Check rev_id - await self.load_page_info() + return True - if not await self.should_update_page_index(): + async def update_title_index(self, remote_update=False): + if not await self.should_update_page_index(remote_update): return False + await self.load_page_info() + self.page_id = self.page_info["pageid"] if self.page_id in self.indexing_page_ids: raise EmbeddingRunningException("Page index is running now") - # Create collection - self.collection_info = await self.title_collection_helper.find_by_title(self.base_title) + self.title: str = self.page_info["title"] + self.base_title = self.title.split("/")[0] + + # Find collection by base title + self.collection_info = await self.title_collection_helper.find_by_title( + self.base_title + ) if self.collection_info is None: - self.collection_id = await self.title_collection_helper.add(self.base_title) - if self.collection_id is None: + # Create collection + self.collection_info = await self.title_collection_helper.add( + self.base_title + ) + if self.collection_info is None: raise Exception("Failed to create title collection") + self.collection_id = self.collection_info.id + + if self.title_index == None: + # Create title record + self.title_index = TitleIndexModel( + title=self.page_info["title"], + page_id=self.page_id, + indexed_rev_id=None, + latest_rev_id=self.page_info["lastrevid"], + collection_id=self.collection_id, + embedding=None, + ) + self.title_index = await self.title_index_helper.add(self.title_index) + if self.title_index is None: + raise Exception("Failed to create title index") else: - self.collection_id = self.collection_info.id + self.title_index.latest_rev_id = self.page_info["lastrevid"] + # Title changed, remove embedding + # Title sha1 will be updated by model helper + if self.title_index.title != self.page_info["title"]: + self.title_index.title = self.page_info["title"] + self.title_index.embedding = None + + # Collection changed, remove old index + if self.collection_id != self.title_index.collection_id: + async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index: + await self.page_index.init_table() + old_page_index.remove_by_page_id(self.page_id) + self.title_index.collection_id = self.collection_id + + await self.title_index_helper.update(self.title_index) + + # Update collection main page id + if ( + self.title == self.collection_info.title + and self.page_id != self.collection_info.page_id + ): + await self.title_collection_helper.set_main_page_id( + self.base_title, self.page_id + ) - self.page_index = PageIndexHelper( - self.dbs, self.collection_id, self.page_id) + if self.page_index: + await self.page_index.__aexit__(None, None, None) + self.page_index = PageIndexHelper(self.dbs, self.collection_id) await self.page_index.__aenter__() await self.page_index.init_table() + async def prepare_update_index(self): + await self.update_title_index() + page_content = await self.mwapi.parse_page(self.title) self.sentences = getWikiSentences(page_content) - self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False) + self.unindexed_docs = await self.page_index.get_unindexed_doc( + self.sentences, self.page_id, with_temporary=False + ) return True @@ -153,15 +225,17 @@ class EmbeddingSearchService: await on_progress(indexed_docs, len(self.unindexed_docs)) async def embedding_doc(doc_chunk): - (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk, on_embedding_progress) - await self.page_index.index_doc(doc_chunk) + (doc_chunk, token_usage) = await self.openai_api.get_embeddings( + doc_chunk, on_embedding_progress + ) + await self.page_index.index_doc(doc_chunk, self.page_id) return token_usage if len(self.unindexed_docs) > 0: if on_progress is not None: await on_progress(0, len(self.unindexed_docs)) - + for doc in self.unindexed_docs: chunk_len += len(doc) @@ -181,55 +255,34 @@ class EmbeddingSearchService: if on_progress is not None: await on_progress(len(self.unindexed_docs), len(self.unindexed_docs)) - await self.page_index.remove_outdated_doc(self.sentences) + await self.page_index.remove_outdated_doc(self.sentences, self.page_id) # Update database - if self.title_info is None: + # This task may take a long time, refresh model to retrieve latest data + self.title_index = await self.title_index_helper.refresh(self.title_index) + + self.title_index.indexed_rev_id = self.page_info["lastrevid"] + + # Update title embedding + if await self.title_index.awaitable_attrs.embedding is None: doc_chunk = [{"text": self.title}] (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) total_token_usage += token_usage embedding = doc_chunk[0]["embedding"] + self.title_index.embedding = embedding - self.title_info = TitleIndexModel( - title=self.page_info["title"], - page_id=self.page_id, - rev_id=self.page_info["lastrevid"], - latest_rev_id=self.page_info["lastrevid"], - collection_id=self.collection_id, - embedding=embedding - ) - res = await self.title_index_helper.add(self.title_info) - if res: - self.title_info = res - else: - if self.title != self.page_info["title"]: - self.title = self.page_info["title"] - - doc_chunk = [{"text": self.title}] - (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) - total_token_usage += token_usage - - embedding = doc_chunk[0]["embedding"] - - self.title_info.title = self.title - self.title_info.rev_id = self.page_info["lastrevid"] - self.title_info.latest_rev_id = self.page_info["lastrevid"] - self.title_info.collection_id = self.collection_id - self.title_info.embedding = embedding - else: - self.title_info.rev_id = self.page_info["lastrevid"] - self.title_info.latest_rev_id = self.page_info["lastrevid"] - - await self.title_index_helper.update(self.title_info) - - if (self.collection_info is None or - (self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)): - await self.title_collection_helper.set_page_id(self.base_title, self.page_id) + await self.title_index_helper.update(self.title_index) return total_token_usage - async def search(self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6): + async def search( + self, + query: str, + limit: int = 10, + in_collection: bool = False, + distance_limit: float = 0.6, + ): if self.page_index is None: raise Exception("Page index is not initialized") @@ -240,7 +293,9 @@ class EmbeddingSearchService: if query_embedding is None: return [], token_usage - res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit) + res = await self.page_index.search_text_embedding( + query_embedding, in_collection, limit, self.page_id + ) if res: filtered = [] for one in res: diff --git a/service/mediawiki_api.py b/service/mediawiki_api.py index f7c45c9..61b913a 100644 --- a/service/mediawiki_api.py +++ b/service/mediawiki_api.py @@ -35,6 +35,10 @@ class MediaWikiUserNoEnoughPointsException(Exception): def __str__(self) -> str: return self.info + +class GetAllPagesResponse(TypedDict): + title_list: list[str] + continue_key: Optional[str] class ChatCompleteGetPointUsageResponse(TypedDict): point_cost: int @@ -125,6 +129,38 @@ class MediaWikiApi: return ret + async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse: + async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: + params = { + "action": "query", + "format": "json", + "formatversion": "2", + "list": "allpages", + "apnamespace": 0, + "apfilterredir": "nonredirects", + "aplimit": 100, + } + if continue_key is not None: + params["apcontinue"] = continue_key + if start_title is not None: + params["apfrom"] = start_title + + async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp: + data = await resp.json() + if "error" in data: + raise MediaWikiApiException(data["error"]["info"], data["error"]["code"]) + + title_list = [] + for page in data["query"]["allpages"]: + title_list.append(page["title"]) + + ret = GetAllPagesResponse(title_list=title_list, continue_key=None) + + if "continue" in data and "apcontinue" in data["continue"]: + ret["continue_key"] = data["continue"]["apcontinue"] + + return ret + async def is_logged_in(self): async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: params = {