将TitleIndexHelper改为ORM调用,修复编辑时报错

master
落雨楓 2 years ago
parent 003a9a7948
commit 99e23f5281

@ -194,8 +194,8 @@ class ChatComplete:
# Remove message after selected message
split_message_pos = None
for i in range(0, len(new_chunk.message_data)):
msg_data = new_chunk.message_data[i]
if msg_data["id"] == msg_id:
msg_data: dict = new_chunk.message_data[i]
if msg_data.get("id") == msg_id:
split_message_pos = i
break

@ -1,14 +1,15 @@
from __future__ import annotations
import hashlib
import asyncpg
import numpy as np
from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector
import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred
from sqlalchemy.ext.asyncio import AsyncEngine
import config
from api.model.base import BaseModel
from api.model.base import BaseHelper, BaseModel
from service.database import DatabaseService
class TitleIndexModel(BaseModel):
@ -19,14 +20,15 @@ class TitleIndexModel(BaseModel):
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'})
class TitleIndexHelper:
class TitleIndexHelper(BaseHelper):
__tablename__ = "embedding_search_title_index"
columns = [
@ -36,12 +38,12 @@ class TitleIndexHelper:
"page_id",
"collection_id",
"rev_id",
"latest_rev_id"
"embedding",
]
def __init__(self, dbs: DatabaseService):
self.dbs = dbs
self.initialized = False
super().__init__(dbs)
async def __aenter__(self):
if not self.initialized:
@ -50,12 +52,11 @@ class TitleIndexHelper:
await register_vector(self.dbi)
self.initialized = True
return self
await super().__aenter__()
async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb)
await super().__aexit__(exc_type, exc, tb)
def get_columns(self, exclude=[]):
if len(exclude) == 0:
@ -66,17 +67,14 @@ class TitleIndexHelper:
"""
Add a title to the index
"""
async def add(self, title: str, page_id: int, rev_id: int, collection_id: int, embedding: np.ndarray):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
ret = await self.dbi.fetchrow("SELECT * FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
if ret is None:
new_id = await self.dbi.fetchval("""INSERT INTO embedding_search_title_index
(sha1, title, page_id, rev_id, collection_id, embedding)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id""",
title_sha1, title, page_id, rev_id, collection_id, embedding, column=0)
return new_id
async def add(self, obj: TitleIndexModel):
obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
if await self.find_by_sha1(obj.sha1) is None:
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
return False
@ -85,25 +83,15 @@ class TitleIndexHelper:
"""
async def remove(self, title: str):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
await self.dbi.execute("DELETE FROM embedding_search_title_index WHERE sha1 = $1", title_sha1)
stmt = sqlalchemy.delete(TitleIndexModel).where(TitleIndexModel.sha1 == title_sha1)
await self.session.execute(stmt)
"""
Update the indexed revision id of a title
"""
async def update_rev_id(self, page_id: int, rev_id: int):
await self.dbi.execute("UPDATE embedding_search_title_index SET rev_id = $1 WHERE page_id = $2", rev_id, page_id)
"""
Update title data
"""
async def update_title_data(self, page_id: int, title: str, rev_id: int, collection_id: int, embedding: np.ndarray):
if collection_page_id is None:
collection_page_id = page_id
await self.dbi.execute("""UPDATE embedding_search_title_index
SET title = $1, rev_id = $2, collection_id = $3, embedding = $4
WHERE page_id = $5""",
title, rev_id, collection_id, embedding, page_id)
async def update(self, obj: TitleIndexModel):
obj.sha1 = hashlib.sha1(obj.title.encode("utf-8")).hexdigest()
await self.session.merge(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj
"""
Search for titles by consine similary
@ -120,43 +108,18 @@ class TitleIndexHelper:
"""
Find a title in the index
"""
async def find_by_title(self, title: str, with_embedding=False):
title_sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
if with_embedding:
columns = self.get_columns()
else:
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetchrow(
"SELECT %s FROM embedding_search_title_index WHERE sha1 = $1" % columns,
title_sha1
)
return ret
async def find_by_page_id(self, page_id: int, with_embedding=False):
if with_embedding:
columns = self.get_columns()
else:
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetchrow(
"SELECT %s FROM embedding_search_title_index WHERE page_id = $1" % columns,
page_id
)
return ret
async def find_by_collection_id(self, collection_id: int, with_embedding=False):
if with_embedding:
columns = self.get_columns()
else:
columns = self.get_columns(exclude=["embedding"])
ret = await self.dbi.fetch(
"SELECT %s FROM embedding_search_title_index WHERE collection_id = $1" % columns,
collection_id
)
return ret
async def find_by_title(self, title: str):
sha1 = hashlib.sha1(title.encode("utf-8")).hexdigest()
return await self.find_by_sha1(sha1)
async def find_by_sha1(self, sha1: str):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.sha1 == sha1)
return await self.session.scalar(stmt)
async def find_by_page_id(self, page_id: int):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
return await self.session.scalar(stmt)
async def find_list_by_collection_id(self, collection_id: int):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
return await self.session.scalars(stmt)

@ -157,11 +157,11 @@ class ChatCompleteService:
edit_message_pos = None
old_tokens = 0
for i in range(0, len(self.conversation_chunk.message_data)):
msg_data = self.conversation_chunk.message_data[i]
if msg_data["id"] == edit_msg_id:
msg_data: dict = self.conversation_chunk.message_data[i]
if msg_data.get("id") == edit_msg_id:
edit_message_pos = i
break
if "tokens" in msg_data and msg_data["tokens"]:
if "tokens" in msg_data and msg_data["tokens"] is not None:
old_tokens += msg_data["tokens"]
if edit_message_pos:

@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Optional, TypedDict
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
from api.model.embedding_search.title_index import TitleIndexHelper
from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
from service.mediawiki_api import MediaWikiApi
@ -26,8 +26,8 @@ class EmbeddingSearchService:
self.title = title
self.base_title = title.split("/")[0]
self.title_index = TitleIndexHelper(dbs)
self.title_collection = TitleCollectionHelper(dbs)
self.title_index_helper = TitleIndexHelper(dbs)
self.title_collection_helper = TitleCollectionHelper(dbs)
self.page_index: PageIndexHelper = None
self.tiktoken: TikTokenService = None
@ -35,11 +35,11 @@ class EmbeddingSearchService:
self.mwapi = MediaWikiApi.create()
self.openai_api = OpenAIApi.create()
self.page_id: int = None
self.collection_id: int = None
self.page_id: Optional[int] = None
self.collection_id: Optional[int] = None
self.title_info: dict = None
self.collection_info: TitleCollectionModel = None
self.title_info: Optional[TitleIndexModel] = None
self.collection_info: Optional[TitleCollectionModel] = None
self.page_info: dict = None
self.unindexed_docs: list = None
@ -47,13 +47,13 @@ class EmbeddingSearchService:
async def __aenter__(self):
self.tiktoken = await TikTokenService.create()
await self.title_index.__aenter__()
await self.title_collection.__aenter__()
await self.title_index_helper.__aenter__()
await self.title_collection_helper.__aenter__()
self.title_info = await self.title_index.find_by_title(self.title)
self.title_info = await self.title_index_helper.find_by_title(self.title)
if self.title_info is not None:
self.page_id = self.title_info["page_id"]
self.collection_id = self.title_info["collection_id"]
self.page_id = self.title_info.page_id
self.collection_id = self.title_info.collection_id
self.page_index = PageIndexHelper(
self.dbs, self.collection_id, self.page_id)
@ -62,8 +62,8 @@ class EmbeddingSearchService:
return self
async def __aexit__(self, exc_type, exc, tb):
await self.title_index.__aexit__(exc_type, exc, tb)
await self.title_collection.__aexit__(exc_type, exc, tb)
await self.title_index_helper.__aexit__(exc_type, exc, tb)
await self.title_collection_helper.__aexit__(exc_type, exc, tb)
if self.page_index is not None:
await self.page_index.__aexit__(exc_type, exc, tb)
@ -82,7 +82,7 @@ class EmbeddingSearchService:
await self.load_page_info()
if (self.title_info is not None and await self.page_index_exists() and
self.title_info["title"] == self.page_info["title"] and self.title_info["rev_id"] == self.page_info["lastrevid"]):
self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]):
# Not changed
return False
@ -101,9 +101,9 @@ class EmbeddingSearchService:
raise EmbeddingRunningException("Page index is running now")
# Create collection
self.collection_info = await self.title_collection.find_by_title(self.base_title)
self.collection_info = await self.title_collection_helper.find_by_title(self.base_title)
if self.collection_info is None:
self.collection_id = await self.title_collection.add(self.base_title)
self.collection_id = await self.title_collection_helper.add(self.base_title)
if self.collection_id is None:
raise Exception("Failed to create title collection")
else:
@ -191,11 +191,17 @@ class EmbeddingSearchService:
embedding = doc_chunk[0]["embedding"]
await self.title_index.add(self.page_info["title"],
self.page_id,
self.page_info["lastrevid"],
self.collection_id,
embedding)
self.title_info = TitleIndexModel(
title=self.page_info["title"],
page_id=self.page_id,
rev_id=self.page_info["lastrevid"],
latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=embedding
)
res = await self.title_index_helper.add(self.title_info)
if res:
self.title_info = res
else:
if self.title != self.page_info["title"]:
self.title = self.page_info["title"]
@ -206,17 +212,20 @@ class EmbeddingSearchService:
embedding = doc_chunk[0]["embedding"]
await self.title_index.update_title_data(self.page_id,
self.title,
self.page_info["lastrevid"],
self.collection_id,
embedding)
self.title_info.title = self.title
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
self.title_info.collection_id = self.collection_id
self.title_info.embedding = embedding
else:
await self.title_index.update_rev_id(self.title, self.page_info["lastrevid"])
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
await self.title_index_helper.update(self.title_info)
if (self.collection_info is None or
(self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)):
await self.title_collection.set_page_id(self.base_title, self.page_id)
await self.title_collection_helper.set_page_id(self.base_title, self.page_id)
return total_token_usage

@ -1,4 +1,7 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import config
config.DEBUG = True

@ -0,0 +1,23 @@
import asyncio
import base
from sqlalchemy import select
from api.model.embedding_search.title_index import TitleIndexModel
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
async def main():
dbs = await DatabaseService.create()
async with dbs.create_session() as session:
stmt = select(TitleIndexModel).where(TitleIndexModel.title == "代号:曙光的世界/黄昏的阿瓦隆")
res = await session.scalar(stmt)
print(res.__dict__)
await asyncio.sleep(0.5)
await local.noawait.end()
if __name__ == '__main__':
local.loop.run_until_complete(main())
Loading…
Cancel
Save