diff --git a/api/controller/ChatComplete.py b/api/controller/ChatComplete.py index 9b5a9c0..ef2c949 100644 --- a/api/controller/ChatComplete.py +++ b/api/controller/ChatComplete.py @@ -194,8 +194,8 @@ class ChatComplete: # Remove message after selected message split_message_pos = None for i in range(0, len(new_chunk.message_data)): - msg_data = new_chunk.message_data[i] - if msg_data["id"] == msg_id: + msg_data: dict = new_chunk.message_data[i] + if msg_data.get("id") == msg_id: split_message_pos = i break diff --git a/api/model/embedding_search/title_index.py b/api/model/embedding_search/title_index.py index 2d1056e..cce023f 100644 --- a/api/model/embedding_search/title_index.py +++ b/api/model/embedding_search/title_index.py @@ -1,14 +1,15 @@ +from __future__ import annotations import hashlib import asyncpg import numpy as np from pgvector.sqlalchemy import Vector from pgvector.asyncpg import register_vector import sqlalchemy -from sqlalchemy.orm import mapped_column, relationship, Mapped +from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred from sqlalchemy.ext.asyncio import AsyncEngine import config -from api.model.base import BaseModel +from api.model.base import BaseHelper, BaseModel from service.database import DatabaseService class TitleIndexModel(BaseModel): @@ -19,14 +20,15 @@ class TitleIndexModel(BaseModel): title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) - rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) - embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)) + rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) + latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) + embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))) embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, postgresql_using='ivfflat', postgresql_ops={'embedding': 'vector_cosine_ops'}) -class TitleIndexHelper: +class TitleIndexHelper(BaseHelper): __tablename__ = "embedding_search_title_index" columns = [ @@ -36,12 +38,12 @@ class TitleIndexHelper: "page_id", "collection_id", "rev_id", + "latest_rev_id" "embedding", ] def __init__(self, dbs: DatabaseService): - self.dbs = dbs - self.initialized = False + super().__init__(dbs) async def __aenter__(self): if not self.initialized: @@ -50,12 +52,11 @@ class TitleIndexHelper: await register_vector(self.dbi) - self.initialized = True - - return self + await super().__aenter__() async def __aexit__(self, exc_type, exc, tb): await self.dbpool.__aexit__(exc_type, exc, tb) + await super().__aexit__(exc_type, exc, tb) def get_columns(self, exclude=[]): if len(exclude) == 0: @@ -66,17 +67,14 @@ class TitleIndexHelper: """ Add a title to the index """ - async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray): - title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() - ret = await self.dbi.fetchrow("SELECT * FROM embedding_search_title_index WHERE sha1 = $1", title_sha1) - - if ret is None: - new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index - (sha1, title, page_id, rev_id, collection_id, embedding) - VALUES ($1, $2, $3, $4, $5, $6) - RETURNING id""", - title_sha1, title, page_id, rev_id, collection_id, embedding, column=0) - return new_id + async def add(self, obj: TitleIndexModel): + obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest() + + if await self.find_by_sha1(obj.sha1) is None: + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj return False @@ -85,25 +83,15 @@ class TitleIndexHelper: """ async def remove(self, title: str): title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() - await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1) + stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1) + await self.session.execute(stmt) - """ - Update the indexed revision id of a title - """ - async def update_rev_id(self, page_id: int, rev_id: int): - await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2", rev_id, page_id) - - """ - Update title data - """ - async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray): - if collection_page_id is None: - collection_page_id = page_id - - await self.dbi.execute("""UPDATE embedding_search_title_index - SET title = $1, rev_id = $2, collection_id = $3, embedding = $4 - WHERE page_id = $5""", - title, rev_id, collection_id, embedding, page_id) + async def update(self, obj: TitleIndexModel): + obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest() + await self.session.merge(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj """ Search for titles by consine similary @@ -120,43 +108,18 @@ class TitleIndexHelper: """ Find a title in the index """ - async def find_by_title(self, title: str, with_embedding=False): - title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() - - if with_embedding: - columns = self.get_columns() - else: - columns = self.get_columns(exclude=["embedding"]) - - ret = await self.dbi.fetchrow( - "SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns, - title_sha1 - ) - - return ret - - async def find_by_page_id(self, page_id: int, with_embedding=False): - if with_embedding: - columns = self.get_columns() - else: - columns = self.get_columns(exclude=["embedding"]) - - ret = await self.dbi.fetchrow( - "SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns, - page_id - ) - - return ret - - async def find_by_collection_id(self, collection_id: int, with_embedding=False): - if with_embedding: - columns = self.get_columns() - else: - columns = self.get_columns(exclude=["embedding"]) - - ret = await self.dbi.fetch( - "SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns, - collection_id - ) - - return ret + async def find_by_title(self, title: str): + sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() + return await self.find_by_sha1(sha1) + + async def find_by_sha1(self, sha1: str): + stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1) + return await self.session.scalar(stmt) + + async def find_by_page_id(self, page_id: int): + stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id) + return await self.session.scalar(stmt) + + async def find_list_by_collection_id(self, collection_id: int): + stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id) + return await self.session.scalars(stmt) \ No newline at end of file diff --git a/service/chat_complete.py b/service/chat_complete.py index f634536..541e19e 100644 --- a/service/chat_complete.py +++ b/service/chat_complete.py @@ -157,11 +157,11 @@ class ChatCompleteService: edit_message_pos = None old_tokens = 0 for i in range(0, len(self.conversation_chunk.message_data)): - msg_data = self.conversation_chunk.message_data[i] - if msg_data["id"] == edit_msg_id: + msg_data: dict = self.conversation_chunk.message_data[i] + if msg_data.get("id") == edit_msg_id: edit_message_pos = i break - if "tokens" in msg_data and msg_data["tokens"]: + if "tokens" in msg_data and msg_data["tokens"] is not None: old_tokens += msg_data["tokens"] if edit_message_pos: diff --git a/service/embedding_search.py b/service/embedding_search.py index fa50ef5..5507ce0 100644 --- a/service/embedding_search.py +++ b/service/embedding_search.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional, TypedDict from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel -from api.model.embedding_search.title_index import TitleIndexHelper +from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel from api.model.embedding_search.page_index import PageIndexHelper from service.database import DatabaseService from service.mediawiki_api import MediaWikiApi @@ -26,8 +26,8 @@ class EmbeddingSearchService: self.title = title self.base_title = title.split("/")[0] - self.title_index = TitleIndexHelper(dbs) - self.title_collection = TitleCollectionHelper(dbs) + self.title_index_helper = TitleIndexHelper(dbs) + self.title_collection_helper = TitleCollectionHelper(dbs) self.page_index: PageIndexHelper = None self.tiktoken: TikTokenService = None @@ -35,11 +35,11 @@ class EmbeddingSearchService: self.mwapi = MediaWikiApi.create() self.openai_api = OpenAIApi.create() - self.page_id: int = None - self.collection_id: int = None + self.page_id: Optional[int] = None + self.collection_id: Optional[int] = None - self.title_info: dict = None - self.collection_info: TitleCollectionModel = None + self.title_info: Optional[TitleIndexModel] = None + self.collection_info: Optional[TitleCollectionModel] = None self.page_info: dict = None self.unindexed_docs: list = None @@ -47,13 +47,13 @@ class EmbeddingSearchService: async def __aenter__(self): self.tiktoken = await TikTokenService.create() - await self.title_index.__aenter__() - await self.title_collection.__aenter__() + await self.title_index_helper.__aenter__() + await self.title_collection_helper.__aenter__() - self.title_info = await self.title_index.find_by_title(self.title) + self.title_info = await self.title_index_helper.find_by_title(self.title) if self.title_info is not None: - self.page_id = self.title_info["page_id"] - self.collection_id = self.title_info["collection_id"] + self.page_id = self.title_info.page_id + self.collection_id = self.title_info.collection_id self.page_index = PageIndexHelper( self.dbs, self.collection_id, self.page_id) @@ -62,8 +62,8 @@ class EmbeddingSearchService: return self async def __aexit__(self, exc_type, exc, tb): - await self.title_index.__aexit__(exc_type, exc, tb) - await self.title_collection.__aexit__(exc_type, exc, tb) + await self.title_index_helper.__aexit__(exc_type, exc, tb) + await self.title_collection_helper.__aexit__(exc_type, exc, tb) if self.page_index is not None: await self.page_index.__aexit__(exc_type, exc, tb) @@ -82,7 +82,7 @@ class EmbeddingSearchService: await self.load_page_info() if (self.title_info is not None and await self.page_index_exists() and - self.title_info["title"] == self.page_info["title"] and self.title_info["rev_id"] == self.page_info["lastrevid"]): + self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]): # Not changed return False @@ -101,9 +101,9 @@ class EmbeddingSearchService: raise EmbeddingRunningException("Page index is running now") # Create collection - self.collection_info = await self.title_collection.find_by_title(self.base_title) + self.collection_info = await self.title_collection_helper.find_by_title(self.base_title) if self.collection_info is None: - self.collection_id = await self.title_collection.add(self.base_title) + self.collection_id = await self.title_collection_helper.add(self.base_title) if self.collection_id is None: raise Exception("Failed to create title collection") else: @@ -191,11 +191,17 @@ class EmbeddingSearchService: embedding = doc_chunk[0]["embedding"] - await self.title_index.add(self.page_info["title"], - self.page_id, - self.page_info["lastrevid"], - self.collection_id, - embedding) + self.title_info = TitleIndexModel( + title=self.page_info["title"], + page_id=self.page_id, + rev_id=self.page_info["lastrevid"], + latest_rev_id=self.page_info["lastrevid"], + collection_id=self.collection_id, + embedding=embedding + ) + res = await self.title_index_helper.add(self.title_info) + if res: + self.title_info = res else: if self.title != self.page_info["title"]: self.title = self.page_info["title"] @@ -206,17 +212,20 @@ class EmbeddingSearchService: embedding = doc_chunk[0]["embedding"] - await self.title_index.update_title_data(self.page_id, - self.title, - self.page_info["lastrevid"], - self.collection_id, - embedding) + self.title_info.title = self.title + self.title_info.rev_id = self.page_info["lastrevid"] + self.title_info.latest_rev_id = self.page_info["lastrevid"] + self.title_info.collection_id = self.collection_id + self.title_info.embedding = embedding else: - await self.title_index.update_rev_id(self.title, self.page_info["lastrevid"]) + self.title_info.rev_id = self.page_info["lastrevid"] + self.title_info.latest_rev_id = self.page_info["lastrevid"] + + await self.title_index_helper.update(self.title_info) if (self.collection_info is None or (self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)): - await self.title_collection.set_page_id(self.base_title, self.page_id) + await self.title_collection_helper.set_page_id(self.base_title, self.page_id) return total_token_usage diff --git a/test/base.py b/test/base.py index bb81f7d..a184749 100644 --- a/test/base.py +++ b/test/base.py @@ -1,4 +1,7 @@ import sys import pathlib -sys.path.append(str(pathlib.Path(__file__).parent.parent)) \ No newline at end of file +sys.path.append(str(pathlib.Path(__file__).parent.parent)) + +import config +config.DEBUG = True \ No newline at end of file diff --git a/test/title_index.py b/test/title_index.py new file mode 100644 index 0000000..a339782 --- /dev/null +++ b/test/title_index.py @@ -0,0 +1,23 @@ +import asyncio +import base + +from sqlalchemy import select +from api.model.embedding_search.title_index import TitleIndexModel +import local +from service.database import DatabaseService + +from service.embedding_search import EmbeddingSearchService + +async def main(): + dbs = await DatabaseService.create() + + async with dbs.create_session() as session: + stmt = select(TitleIndexModel).where(TitleIndexModel.title == "代号:曙光的世界/黄昏的阿瓦隆") + res = await session.scalar(stmt) + print(res.__dict__) + + await asyncio.sleep(0.5) + await local.noawait.end() + +if __name__ == '__main__': + local.loop.run_until_complete(main()) \ No newline at end of file