将TitleIndexHelper改为ORM调用,修复编辑时报错

master
落雨楓 2 years ago
parent 003a9a7948
commit 99e23f5281

@ -194,8 +194,8 @@ class ChatComplete:
# Remove message after selected message # Remove message after selected message
split_message_pos = None split_message_pos = None
for i in range(0, len(new_chunk.message_data)): for i in range(0, len(new_chunk.message_data)):
msg_data = new_chunk.message_data[i] msg_data: dict = new_chunk.message_data[i]
if msg_data["id"] == msg_id: if msg_data.get("id") == msg_id:
split_message_pos = i split_message_pos = i
break break

@ -1,14 +1,15 @@
from __future__ import annotations
import hashlib import hashlib
import asyncpg import asyncpg
import numpy as np import numpy as np
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector from pgvector.asyncpg import register_vector
import sqlalchemy import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
import config import config
from api.model.base import BaseModel from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService from service.database import DatabaseService
class TitleIndexModel(BaseModel): class TitleIndexModel(BaseModel):
@ -19,14 +20,15 @@ class TitleIndexModel(BaseModel):
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)) latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat', postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'}) postgresql_ops={'embedding': 'vector_cosine_ops'})
class TitleIndexHelper: class TitleIndexHelper(BaseHelper):
__tablename__ = "embedding_search_title_index" __tablename__ = "embedding_search_title_index"
columns = [ columns = [
@ -36,12 +38,12 @@ class TitleIndexHelper:
"page_id", "page_id",
"collection_id", "collection_id",
"rev_id", "rev_id",
"latest_rev_id",  # trailing comma is required: without it Python joins this literal with the next one ("latest_rev_idembedding")
"embedding", "embedding",
] ]
def __init__(self, dbs: DatabaseService): def __init__(self, dbs: DatabaseService):
self.dbs = dbs super().__init__(dbs)
self.initialized = False
async def __aenter__(self): async def __aenter__(self):
if not self.initialized: if not self.initialized:
@ -50,12 +52,11 @@ class TitleIndexHelper:
await register_vector(self.dbi) await register_vector(self.dbi)
self.initialized = True await super().__aenter__()
return self
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb) await self.dbpool.__aexit__(exc_type, exc, tb)
await super().__aexit__(exc_type, exc, tb)
def get_columns(self, exclude=[]): def get_columns(self, exclude=[]):
if len(exclude) == 0: if len(exclude) == 0:
@ -66,17 +67,14 @@ class TitleIndexHelper:
""" """
Add a title to the index Add a title to the index
""" """
async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray): async def add(self, obj: TitleIndexModel):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
ret = await self.dbi.fetchrow("SELECT * FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
if ret is None: if await self.find_by_sha1(obj.sha1) is None:
new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index self.session.add(obj)
(sha1, title, page_id, rev_id, collection_id, embedding) await self.session.commit()
VALUES ($1, $2, $3, $4, $5, $6) await self.session.refresh(obj)
RETURNING id""", return obj
title_sha1, title, page_id, rev_id, collection_id, embedding, column=0)
return new_id
return False return False
@ -85,25 +83,15 @@ class TitleIndexHelper:
""" """
async def remove(self, title: str): async def remove(self, title: str):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1) stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1)
await self.session.execute(stmt)
""" async def update(self, obj: TitleIndexModel):
Update the indexed revision id of a title obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
""" await self.session.merge(obj)
async def update_rev_id(self, page_id: int, rev_id: int): await self.session.commit()
await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2", rev_id, page_id) await self.session.refresh(obj)
return obj
"""
Update title data
"""
async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray):
if collection_page_id is None:
collection_page_id = page_id
await self.dbi.execute("""UPDATE embedding_search_title_index
SET title = $1, rev_id = $2, collection_id = $3, embedding = $4
WHERE page_id = $5""",
title, rev_id, collection_id, embedding, page_id)
""" """
Search for titles by cosine similarity Search for titles by cosine similarity
@ -120,43 +108,18 @@ class TitleIndexHelper:
""" """
Find a title in the index Find a title in the index
""" """
async def find_by_title(self, title: str, with_embedding=False): async def find_by_title(self, title: str):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest() sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
return await self.find_by_sha1(sha1)
if with_embedding: async def find_by_sha1(self, sha1: str):
columns = self.get_columns() stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1)
else: return await self.session.scalar(stmt)
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetchrow( async def find_by_page_id(self, page_id: int):
"SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns, stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
title_sha1 return await self.session.scalar(stmt)
)
return ret async def find_list_by_collection_id(self, collection_id: int):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
async def find_by_page_id(self, page_id: int, with_embedding=False): return await self.session.scalars(stmt)
if with_embedding:
columns = self.get_columns()
else:
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetchrow(
"SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns,
page_id
)
return ret
async def find_by_collection_id(self, collection_id: int, with_embedding=False):
if with_embedding:
columns = self.get_columns()
else:
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetch(
"SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns,
collection_id
)
return ret

@ -157,11 +157,11 @@ class ChatCompleteService:
edit_message_pos = None edit_message_pos = None
old_tokens = 0 old_tokens = 0
for i in range(0, len(self.conversation_chunk.message_data)): for i in range(0, len(self.conversation_chunk.message_data)):
msg_data = self.conversation_chunk.message_data[i] msg_data: dict = self.conversation_chunk.message_data[i]
if msg_data["id"] == edit_msg_id: if msg_data.get("id") == edit_msg_id:
edit_message_pos = i edit_message_pos = i
break break
if "tokens" in msg_data and msg_data["tokens"]: if "tokens" in msg_data and msg_data["tokens"] is not None:
old_tokens += msg_data["tokens"] old_tokens += msg_data["tokens"]
if edit_message_pos: if edit_message_pos:

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, TypedDict from typing import Optional, TypedDict
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
from api.model.embedding_search.title_index import TitleIndexHelper from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
from api.model.embedding_search.page_index import PageIndexHelper from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi from service.mediawiki_api import MediaWikiApi
@ -26,8 +26,8 @@ class EmbeddingSearchService:
self.title = title self.title = title
self.base_title = title.split("/")[0] self.base_title = title.split("/")[0]
self.title_index = TitleIndexHelper(dbs) self.title_index_helper = TitleIndexHelper(dbs)
self.title_collection = TitleCollectionHelper(dbs) self.title_collection_helper = TitleCollectionHelper(dbs)
self.page_index: PageIndexHelper = None self.page_index: PageIndexHelper = None
self.tiktoken: TikTokenService = None self.tiktoken: TikTokenService = None
@ -35,11 +35,11 @@ class EmbeddingSearchService:
self.mwapi = MediaWikiApi.create() self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create() self.openai_api = OpenAIApi.create()
self.page_id: int = None self.page_id: Optional[int] = None
self.collection_id: int = None self.collection_id: Optional[int] = None
self.title_info: dict = None self.title_info: Optional[TitleIndexModel] = None
self.collection_info: TitleCollectionModel = None self.collection_info: Optional[TitleCollectionModel] = None
self.page_info: dict = None self.page_info: dict = None
self.unindexed_docs: list = None self.unindexed_docs: list = None
@ -47,13 +47,13 @@ class EmbeddingSearchService:
async def __aenter__(self): async def __aenter__(self):
self.tiktoken = await TikTokenService.create() self.tiktoken = await TikTokenService.create()
await self.title_index.__aenter__() await self.title_index_helper.__aenter__()
await self.title_collection.__aenter__() await self.title_collection_helper.__aenter__()
self.title_info = await self.title_index.find_by_title(self.title) self.title_info = await self.title_index_helper.find_by_title(self.title)
if self.title_info is not None: if self.title_info is not None:
self.page_id = self.title_info["page_id"] self.page_id = self.title_info.page_id
self.collection_id = self.title_info["collection_id"] self.collection_id = self.title_info.collection_id
self.page_index = PageIndexHelper( self.page_index = PageIndexHelper(
self.dbs, self.collection_id, self.page_id) self.dbs, self.collection_id, self.page_id)
@ -62,8 +62,8 @@ class EmbeddingSearchService:
return self return self
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.title_index.__aexit__(exc_type, exc, tb) await self.title_index_helper.__aexit__(exc_type, exc, tb)
await self.title_collection.__aexit__(exc_type, exc, tb) await self.title_collection_helper.__aexit__(exc_type, exc, tb)
if self.page_index is not None: if self.page_index is not None:
await self.page_index.__aexit__(exc_type, exc, tb) await self.page_index.__aexit__(exc_type, exc, tb)
@ -82,7 +82,7 @@ class EmbeddingSearchService:
await self.load_page_info() await self.load_page_info()
if (self.title_info is not None and await self.page_index_exists() and if (self.title_info is not None and await self.page_index_exists() and
self.title_info["title"] == self.page_info["title"] and self.title_info["rev_id"] == self.page_info["lastrevid"]): self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]):
# Not changed # Not changed
return False return False
@ -101,9 +101,9 @@ class EmbeddingSearchService:
raise EmbeddingRunningException("Page index is running now") raise EmbeddingRunningException("Page index is running now")
# Create collection # Create collection
self.collection_info = await self.title_collection.find_by_title(self.base_title) self.collection_info = await self.title_collection_helper.find_by_title(self.base_title)
if self.collection_info is None: if self.collection_info is None:
self.collection_id = await self.title_collection.add(self.base_title) self.collection_id = await self.title_collection_helper.add(self.base_title)
if self.collection_id is None: if self.collection_id is None:
raise Exception("Failed to create title collection") raise Exception("Failed to create title collection")
else: else:
@ -191,11 +191,17 @@ class EmbeddingSearchService:
embedding = doc_chunk[0]["embedding"] embedding = doc_chunk[0]["embedding"]
await self.title_index.add(self.page_info["title"], self.title_info = TitleIndexModel(
self.page_id, title=self.page_info["title"],
self.page_info["lastrevid"], page_id=self.page_id,
self.collection_id, rev_id=self.page_info["lastrevid"],
embedding) latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=embedding
)
res = await self.title_index_helper.add(self.title_info)
if res:
self.title_info = res
else: else:
if self.title != self.page_info["title"]: if self.title != self.page_info["title"]:
self.title = self.page_info["title"] self.title = self.page_info["title"]
@ -206,17 +212,20 @@ class EmbeddingSearchService:
embedding = doc_chunk[0]["embedding"] embedding = doc_chunk[0]["embedding"]
await self.title_index.update_title_data(self.page_id, self.title_info.title = self.title
self.title, self.title_info.rev_id = self.page_info["lastrevid"]
self.page_info["lastrevid"], self.title_info.latest_rev_id = self.page_info["lastrevid"]
self.collection_id, self.title_info.collection_id = self.collection_id
embedding) self.title_info.embedding = embedding
else: else:
await self.title_index.update_rev_id(self.title, self.page_info["lastrevid"]) self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
await self.title_index_helper.update(self.title_info)
if (self.collection_info is None or if (self.collection_info is None or
(self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)): (self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)):
await self.title_collection.set_page_id(self.base_title, self.page_id) await self.title_collection_helper.set_page_id(self.base_title, self.page_id)
return total_token_usage return total_token_usage

@ -2,3 +2,6 @@ import sys
import pathlib import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent)) sys.path.append(str(pathlib.Path(__file__).parent.parent))
import config
config.DEBUG = True

@ -0,0 +1,23 @@
import asyncio
import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
async def main():
    """Look up one title-index row by title and print its attributes.

    Debug helper: creates the database service, opens a session, selects the
    ``TitleIndexModel`` row whose ``title`` equals the hard-coded page title,
    and prints the row's ``__dict__``. If no row matches, a short message is
    printed instead of raising ``AttributeError``.
    """
    dbs = await DatabaseService.create()

    async with dbs.create_session() as session:
        stmt = select(TitleIndexModel).where(TitleIndexModel.title == "代号:曙光的世界/黄昏的阿瓦隆")
        res = await session.scalar(stmt)
        # scalar() returns None when the query matches no row; guard before
        # touching attributes of the result.
        if res is None:
            print("No TitleIndexModel row found for the requested title")
        else:
            print(res.__dict__)

    # NOTE(review): presumably the short sleep lets pending background work
    # finish before shutdown — confirm against local.noawait.
    await asyncio.sleep(0.5)
    await local.noawait.end()


if __name__ == '__main__':
    local.loop.run_until_complete(main())
Loading…
Cancel
Save