增加在Collection中提问的功能

master
落雨楓 2 years ago
parent 7b4c70147b
commit 4aa627a23e

@ -1,6 +1,8 @@
import sys import sys
import traceback import traceback
from aiohttp import web from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionHelper
from api.model.embedding_search.title_index import TitleIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
@ -14,9 +16,15 @@ class EmbeddingSearch:
"title": { "title": {
"required": True, "required": True,
}, },
"collection": {
"required": False,
"type": bool,
"default": False
}
}) })
page_title = params.get('title') page_title = params.get("title")
is_collection = params.get("collection")
mwapi = MediaWikiApi.create() mwapi = MediaWikiApi.create()
db = await DatabaseService.create(request.app) db = await DatabaseService.create(request.app)
@ -27,7 +35,23 @@ class EmbeddingSearch:
try: try:
transatcion_id = None transatcion_id = None
async with EmbeddingSearchService(db, page_title) as embedding_search: title_list = [page_title]
if is_collection:
# Get collection titles
async with TitleCollectionHelper(db) as title_collection, TitleIndexHelper(db) as title_index_helper:
title_collection = await title_collection.find_by_title(page_title)
if title_collection is not None:
need_update_pages = await title_index_helper.get_need_update_index_list(title_collection.id)
title_list = []
for page_info in need_update_pages:
title_list.append(page_info.title)
page_count = len(title_list)
page_current = 0
index_updated = False
for one_title in title_list:
page_current += 1
async with EmbeddingSearchService(db, one_title) as embedding_search:
if await embedding_search.should_update_page_index(): if await embedding_search.should_update_page_index():
if request.get("caller") == "user": if request.get("caller") == "user":
user_id = request.get("user") user_id = request.get("user")
@ -37,37 +61,34 @@ class EmbeddingSearch:
await embedding_search.prepare_update_index() await embedding_search.prepare_update_index()
async def on_progress(current, total): async def on_progress(current, total):
await ws.send_json({ await ws.send_json({
'event': 'progress', "event": "progress",
'current': current, "current": current,
'total': total "total": total,
"current_page": page_current,
"total_page": page_count,
}) })
token_usage = await embedding_search.update_page_index(on_progress) token_usage = await embedding_search.update_page_index(on_progress)
await ws.send_json({ index_updated = True
'event': 'done',
'status': 1,
'index_updated': True
})
if transatcion_id: if transatcion_id:
await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
else:
await ws.send_json({ await ws.send_json({
'event': 'done', "event": "done",
'status': 1, "status": 1,
'index_updated': False "index_updated": index_updated,
}) })
except MediaWikiPageNotFoundException: except MediaWikiPageNotFoundException:
error_msg = "Page \"%s\" not found." % page_title error_msg = "Page \"%s\" not found." % page_title
await ws.send_json({ await ws.send_json({
'event': 'error', "event": "error",
'status': -2, "status": -2,
'message': error_msg, "message": error_msg,
'error': { "error": {
'code': 'page_not_found', "code": "page_not_found",
'title': page_title, "title": page_title,
}, },
}) })
if transatcion_id: if transatcion_id:
@ -77,12 +98,12 @@ class EmbeddingSearch:
print(error_msg, file=sys.stderr) print(error_msg, file=sys.stderr)
traceback.print_exc() traceback.print_exc()
await ws.send_json({ await ws.send_json({
'event': 'error', "event": "error",
'status': -3, "status": -3,
'message': error_msg, "message": error_msg,
'error': { "error": {
'code': e.code, "code": e.code,
'info': e.info, "info": e.info,
}, },
}) })
if transatcion_id: if transatcion_id:
@ -90,11 +111,11 @@ class EmbeddingSearch:
except EmbeddingRunningException: except EmbeddingRunningException:
error_msg = "Page index is running now" error_msg = "Page index is running now"
await ws.send_json({ await ws.send_json({
'event': 'error', "event": "error",
'status': -4, "status": -4,
'message': error_msg, "message": error_msg,
'error': { "error": {
'code': 'page_index_running', "code": "page_index_running",
}, },
}) })
if transatcion_id: if transatcion_id:
@ -104,11 +125,11 @@ class EmbeddingSearch:
print(error_msg, file=sys.stderr) print(error_msg, file=sys.stderr)
traceback.print_exc() traceback.print_exc()
await ws.send_json({ await ws.send_json({
'event': 'error', "event": "error",
'status': -1, "status": -1,
'message': error_msg, "message": error_msg,
'error': { "error": {
'code': 'internal_server_error', "code": "internal_server_error",
} }
}) })
if transatcion_id: if transatcion_id:
@ -133,9 +154,15 @@ class EmbeddingSearch:
if transatcion_id: if transatcion_id:
result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage) result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
return await utils.web.api_response(1, {"data_indexed": True}) return await utils.web.api_response(1, {
"data_indexed": True,
"title": embedding_search.title
})
else: else:
return await utils.web.api_response(1, {"data_indexed": False}) return await utils.web.api_response(1, {
"data_indexed": False,
"title": embedding_search.title
})
except MediaWikiPageNotFoundException: except MediaWikiPageNotFoundException:
error_msg = "Page \"%s\" not found." % page_title error_msg = "Page \"%s\" not found." % page_title
if transatcion_id: if transatcion_id:
@ -254,3 +281,42 @@ class EmbeddingSearch:
"message": error_msg "message": error_msg
}, request=request, http_status=500) }, request=request, http_status=500)
return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request) return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)
@staticmethod
@utils.web.token_auth
async def sys_update_title_info(request: web.Request):
    """System-only endpoint: rebuild the title index entry for one page.

    Expects a required ``title`` parameter. Rejects user callers with 403;
    returns 404 when the wiki page does not exist and 500 on any other
    failure from the embedding-search service.
    """
    req_params = await utils.web.get_param(request, {
        "title": {
            "required": True,
        },
    })
    # Guard: this route must only be reachable by the system caller.
    if request.get("caller") == "user":
        return await utils.web.api_response(-1, error={
            "code": "permission-denied",
            "message": "This api is only for system caller."
        }, request=request, http_status=403)
    title = req_params.get("title")
    database = await DatabaseService.create(request.app)
    async with EmbeddingSearchService(db, title) if False else EmbeddingSearchService(database, title) as embedding_search:
        try:
            # Force-refresh the title index for this page.
            await embedding_search.update_title_index(True)
        except MediaWikiPageNotFoundException:
            message = "Page \"%s\" not found." % title
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": title,
                "message": message
            }, request=request, http_status=404)
        except Exception as err:
            # Surface unexpected failures as a generic 500 payload.
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": str(err)
            }, request=request, http_status=500)
    return await utils.web.api_response(1, request=request)

@ -129,6 +129,7 @@ class ChatCompleteTask:
async def end(self): async def end(self):
await self.chat_complete_service.__aexit__(None, None, None) await self.chat_complete_service.__aexit__(None, None, None)
if self.task_id in chat_complete_tasks:
del chat_complete_tasks[self.task_id] del chat_complete_tasks[self.task_id]
self.is_finished = True self.is_finished = True
self.finished_time = time.time() self.finished_time = time.time()

@ -2,10 +2,11 @@ from __future__ import annotations
from typing import TypeVar from typing import TypeVar
import sqlalchemy import sqlalchemy
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from service.database import DatabaseService from service.database import DatabaseService
class BaseModel(DeclarativeBase): class BaseModel(AsyncAttrs, DeclarativeBase):
pass pass
class BaseHelper: class BaseHelper:

@ -21,7 +21,9 @@ page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
class AbstractPageIndexModel(BaseModel): class AbstractPageIndexModel(BaseModel):
__abstract__ = True __abstract__ = True
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=True
)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True) sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)) embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
@ -36,34 +38,37 @@ def create_page_index_model(collection_id: int):
if collection_id in page_index_model_list: if collection_id in page_index_model_list:
return page_index_model_list[collection_id] return page_index_model_list[collection_id]
else: else:
class _PageIndexModel(AbstractPageIndexModel): class _PageIndexModel(AbstractPageIndexModel):
__tablename__ = "embedding_search_page_index_%s" % str(collection_id) __tablename__ = "embedding_search_page_index_%s" % str(collection_id)
embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx", AbstractPageIndexModel.embedding, embedding_index = sqlalchemy.Index(
postgresql_using='ivfflat', __tablename__ + "_embedding_idx",
postgresql_ops={'embedding': 'vector_cosine_ops'}) AbstractPageIndexModel.embedding,
postgresql_using="ivfflat",
postgresql_ops={"embedding": "vector_cosine_ops"},
)
page_index_model_list[collection_id] = _PageIndexModel page_index_model_list[collection_id] = _PageIndexModel
return _PageIndexModel return _PageIndexModel
class PageIndexHelper: class PageIndexHelper:
columns = [ columns = [
"id", "id",
"page_id" "page_id" "sha1",
"sha1",
"text", "text",
"text_len", "text_len",
"markdown", "markdown",
"markdown_len", "markdown_len",
"embedding", "embedding",
"temp_doc_session_id" "temp_doc_session_id",
] ]
def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]): def __init__(self, dbs: DatabaseService, collection_id: int):
self.dbs = dbs self.dbs = dbs
self.collection_id = collection_id self.collection_id = collection_id
self.page_id = page_id if page_id is not None else -1
self.table_name = "embedding_search_page_index_%s" % str(collection_id) self.table_name = "embedding_search_page_index_%s" % str(collection_id)
self.initialized = False self.initialized = False
self.table_initialized = False self.table_initialized = False
@ -71,6 +76,7 @@ class PageIndexHelper:
""" """
Initialize table Initialize table
""" """
async def __aenter__(self): async def __aenter__(self):
if self.initialized: if self.initialized:
return return
@ -95,12 +101,16 @@ class PageIndexHelper:
await self.session.__aexit__(exc_type, exc, tb) await self.session.__aexit__(exc_type, exc, tb)
async def table_exists(self): async def table_exists(self):
exists = await self.dbi.fetchval("""SELECT EXISTS ( exists = await self.dbi.fetchval(
"""SELECT EXISTS (
SELECT 1 SELECT 1
FROM information_schema.tables FROM information_schema.tables
WHERE table_schema = 'public' WHERE table_schema = 'public'
AND table_name = $1 AND table_name = $1
);""", self.table_name, column=0) );""",
self.table_name,
column=0,
)
return bool(exists) return bool(exists)
@ -113,22 +123,6 @@ class PageIndexHelper:
async with self.dbs.engine.begin() as conn: async with self.dbs.engine.begin() as conn:
await conn.run_sync(self.orm.__table__.create) await conn.run_sync(self.orm.__table__.create)
# await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
# id SERIAL PRIMARY KEY,
# page_id INTEGER NOT NULL,
# sha1 VARCHAR(40) NOT NULL,
# text TEXT NOT NULL,
# text_len INTEGER NOT NULL,
# embedding VECTOR(%d) NOT NULL,
# markdown TEXT NULL,
# markdown_len INTEGER NULL,
# temp_doc_session_id INTEGER NULL
# );
# CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
# CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
# CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
# """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
self.table_initialized = True self.table_initialized = True
async def create_embedding_index(self): async def create_embedding_index(self):
@ -142,7 +136,9 @@ class PageIndexHelper:
sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest() sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
item["sha1"] = sha1 item["sha1"] = sha1
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False): async def get_indexed_sha1(
self, page_id: int, with_temporary: bool = True, in_collection: bool = False
):
indexed_sha1_list = [] indexed_sha1_list = []
stmt = select(self.orm).column(self.orm.sha1) stmt = select(self.orm).column(self.orm.sha1)
@ -151,7 +147,7 @@ class PageIndexHelper:
stmt = stmt.where(self.orm.temp_doc_session_id == None) stmt = stmt.where(self.orm.temp_doc_session_id == None)
if not in_collection: if not in_collection:
stmt = stmt.where(self.orm.page_id == self.page_id) stmt = stmt.where(self.orm.page_id == page_id)
ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt) ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
@ -160,8 +156,10 @@ class PageIndexHelper:
return indexed_sha1_list return indexed_sha1_list
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True): async def get_unindexed_doc(
indexed_sha1_list = await self.get_indexed_sha1(with_temporary) self, doc: list, page_id: int, with_temporary: bool = True
):
indexed_sha1_list = await self.get_indexed_sha1(page_id, with_temporary)
self.sha1_doc(doc) self.sha1_doc(doc)
should_index = [] should_index = []
@ -171,10 +169,10 @@ class PageIndexHelper:
return should_index return should_index
async def remove_outdated_doc(self, doc: list): async def remove_outdated_doc(self, doc: list, page_id: int):
await self.clear_temp() await self.clear_temp(page_id=page_id)
indexed_sha1_list = await self.get_indexed_sha1(False) indexed_sha1_list = await self.get_indexed_sha1(page_id, False)
self.sha1_doc(doc) self.sha1_doc(doc)
doc_sha1_list = [item["sha1"] for item in doc] doc_sha1_list = [item["sha1"] for item in doc]
@ -185,17 +183,24 @@ class PageIndexHelper:
should_remove.append(sha1) should_remove.append(sha1)
if len(should_remove) > 0: if len(should_remove) > 0:
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name), await self.dbi.execute(
self.page_id, should_remove) "DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)"
% (self.table_name),
async def index_doc(self, doc: list): page_id,
should_remove,
)
async def index_doc(self, doc: list, page_id: int):
need_create_index = False need_create_index = False
indexed_persist_sha1_list = [] indexed_persist_sha1_list = []
indexed_temp_sha1_list = [] indexed_temp_sha1_list = []
ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name), ret = await self.dbi.fetch(
self.page_id) "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1"
% (self.table_name),
page_id,
)
for row in ret: for row in ret:
if row[1]: if row[1]:
indexed_temp_sha1_list.append(row[0]) indexed_temp_sha1_list.append(row[0])
@ -226,13 +231,30 @@ class PageIndexHelper:
should_remove.append(sha1) should_remove.append(sha1)
if len(should_index) > 0: if len(should_index) > 0:
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) await self.dbi.executemany(
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name), """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index]) VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);"""
% (self.table_name),
[
(
item["sha1"],
page_id,
item["text"],
len(item["text"]),
item["markdown"],
len(item["markdown"]),
item["embedding"],
)
for item in should_index
],
)
if len(should_persist) > 0: if len(should_persist) > 0:
await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name), await self.dbi.executemany(
[(self.page_id, sha1) for sha1 in should_persist]) "UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2"
% (self.table_name),
[(page_id, sha1) for sha1 in should_persist],
)
if need_create_index: if need_create_index:
await self.create_embedding_index() await self.create_embedding_index()
@ -240,14 +262,17 @@ class PageIndexHelper:
""" """
Add temporary document to the index Add temporary document to the index
""" """
async def index_temp_doc(self, doc: list, temp_doc_session_id: int): async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
indexed_sha1_list = [] indexed_sha1_list = []
indexed_temp_sha1_list = [] indexed_temp_sha1_list = []
doc_sha1_list = [] doc_sha1_list = []
sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % ( sql = (
self.table_name) "SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)"
ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id) % (self.table_name)
)
ret = await self.dbi.fetch(sql, temp_doc_session_id)
for row in ret: for row in ret:
indexed_sha1_list.append(row[0]) indexed_sha1_list.append(row[0])
if row[1]: if row[1]:
@ -269,41 +294,79 @@ class PageIndexHelper:
should_remove.append(sha1) should_remove.append(sha1)
if len(should_index) > 0: if len(should_index) > 0:
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id) await self.dbi.executemany(
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name), """INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index]) VALUES ($1, $2, $3, $4, $5, $6, $7, $8);"""
% (self.table_name),
[
(
item["sha1"],
item["page_id"],
item["text"],
len(item["text"]),
item["markdown"],
len(item["markdown"]),
item["embedding"],
temp_doc_session_id,
)
for item in should_index
],
)
if len(should_remove) > 0: if len(should_remove) > 0:
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name), await self.dbi.execute(
self.page_id, temp_doc_session_id, should_remove) "DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)"
% (self.table_name),
temp_doc_session_id,
should_remove,
)
""" """
Search for text by cosine similarity Search for text by cosine similarity
""" """
async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
async def search_text_embedding(
self,
embedding: np.ndarray,
in_collection: bool = False,
limit: int = 10,
page_id: Optional[int] = None,
):
if in_collection: if in_collection:
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance return await self.dbi.fetch(
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
FROM %s FROM %s
ORDER BY distance ASC ORDER BY distance ASC
LIMIT %d""" % (self.table_name, limit), embedding) LIMIT %d"""
% (self.table_name, limit),
embedding,
)
else: else:
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance return await self.dbi.fetch(
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
FROM %s FROM %s
WHERE page_id = $2 WHERE page_id = $2
ORDER BY distance ASC ORDER BY distance ASC
LIMIT %d""" % (self.table_name, limit), embedding, self.page_id) LIMIT %d"""
% (self.table_name, limit),
embedding,
page_id,
)
""" """
Clear temporary index Clear temporary index
""" """
async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None):
async def clear_temp(
self, in_collection: bool = False, temp_doc_session_id: int = None, page_id: Optional[int] = None
):
sql = "DELETE FROM %s" % (self.table_name) sql = "DELETE FROM %s" % (self.table_name)
where = [] where = []
params = [] params = []
if not in_collection: if not in_collection and page_id:
params.append(self.page_id) params.append(page_id)
where.append("page_id = $%d" % len(params)) where.append("page_id = $%d" % len(params))
if temp_doc_session_id: if temp_doc_session_id:
@ -316,3 +379,8 @@ class PageIndexHelper:
sql += " WHERE " + (" AND ".join(where)) sql += " WHERE " + (" AND ".join(where))
await self.dbi.execute(sql, *params) await self.dbi.execute(sql, *params)
async def remove_by_page_id(self, page_id: int):
    """Delete every index row belonging to the given wiki page id and commit."""
    await self.session.execute(
        delete(self.orm).where(self.orm.page_id == page_id)
    )
    await self.session.commit()

@ -23,11 +23,11 @@ class TitleCollectionHelper(BaseHelper):
self.session.add(obj) self.session.add(obj)
await self.session.commit() await self.session.commit()
await self.session.refresh(obj) await self.session.refresh(obj)
return obj.id return obj
return False return None
async def set_page_id(self, title: str, page_id: Optional[str] = None): async def set_main_page_id(self, title: str, page_id: Optional[str] = None):
stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id) stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id)
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
@ -37,6 +37,10 @@ class TitleCollectionHelper(BaseHelper):
await self.session.execute(stmt) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
async def find_by_id(self, id: int):
    """Look up a title collection by primary key; returns the model or None."""
    query = select(TitleCollectionModel).where(TitleCollectionModel.id == id)
    return await self.session.scalar(query)
async def find_by_title(self, title: str): async def find_by_title(self, title: str):
stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title) stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
return await self.session.scalar(stmt) return await self.session.scalar(stmt)

@ -5,7 +5,7 @@ import numpy as np
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector from pgvector.asyncpg import register_vector
import sqlalchemy import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
import config import config
@ -20,9 +20,9 @@ class TitleIndexModel(BaseModel):
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True) title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True) collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer) latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))) embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding, embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat', postgresql_using='ivfflat',
@ -37,7 +37,7 @@ class TitleIndexHelper(BaseHelper):
"title", "title",
"page_id", "page_id",
"collection_id", "collection_id",
"rev_id", "indexed_rev_id",
"latest_rev_id" "latest_rev_id"
"embedding", "embedding",
] ]
@ -52,11 +52,11 @@ class TitleIndexHelper(BaseHelper):
await register_vector(self.dbi) await register_vector(self.dbi)
await super().__aenter__() return await super().__aenter__()
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb) await self.dbpool.__aexit__(exc_type, exc, tb)
await super().__aexit__(exc_type, exc, tb) return await super().__aexit__(exc_type, exc, tb)
def get_columns(self, exclude=[]): def get_columns(self, exclude=[]):
if len(exclude) == 0: if len(exclude) == 0:
@ -93,6 +93,10 @@ class TitleIndexHelper(BaseHelper):
await self.session.refresh(obj) await self.session.refresh(obj)
return obj return obj
async def refresh(self, obj: TitleIndexModel):
    """Reload *obj*'s column values from the database and return the same object."""
    # Delegates to the AsyncSession so deferred/expired attributes are repopulated.
    await self.session.refresh(obj)
    return obj
""" """
Search for titles by cosine similarity Search for titles by cosine similarity
""" """
@ -120,6 +124,15 @@ class TitleIndexHelper(BaseHelper):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id) stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
return await self.session.scalar(stmt) return await self.session.scalar(stmt)
async def find_list_by_collection_id(self, collection_id: int): async def find_list_by_collection_id(self, collection_id: int) -> sqlalchemy.ScalarResult[TitleIndexModel]:
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id) stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
return await self.session.scalars(stmt) return await self.session.scalars(stmt)
async def get_need_update_index_list(self, collection_id: int) -> list[TitleIndexModel]:
    """Return the pages in a collection whose stored index lags the latest revision.

    A page needs reindexing whenever its ``indexed_rev_id`` differs from
    ``latest_rev_id`` (including a NULL ``indexed_rev_id`` for never-indexed pages).
    """
    pages = await self.find_list_by_collection_id(collection_id)
    return [page for page in pages if page.indexed_rev_id != page.latest_rev_id]

@ -28,6 +28,7 @@ def init(app: web.Application):
web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page), web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page),
web.route('*', '/embedding_search/search', EmbeddingSearch.search), web.route('*', '/embedding_search/search', EmbeddingSearch.search),
web.route('POST', '/sys/embedding_search/title/update', EmbeddingSearch.sys_update_title_info),
web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list), web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk), web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk),

@ -0,0 +1,4 @@
# Bootstrap helper for standalone scripts in this directory.
import sys
import pathlib
# Add the repository root (the parent of this script's directory) to the module
# search path so project packages resolve when a script is run directly.
sys.path.append(str(pathlib.Path(__file__).parent.parent))

@ -0,0 +1,35 @@
import asyncio
import base as _
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
async def main():
    """Walk every wiki page via the MediaWiki API and rebuild its title index.

    Pages are fetched in batches using the API continuation cursor; each title
    is force-reindexed through EmbeddingSearchService. Finishes by shutting
    down the local no-await task queue.
    """
    db_service = await DatabaseService.create()
    wiki = MediaWikiApi.create()
    cursor = None
    while True:
        batch = await wiki.get_all_pages(cursor)
        for title in batch["title_list"]:
            print("Indexing %s" % title)
            async with EmbeddingSearchService(db_service, title) as embedding_search:
                await embedding_search.update_title_index(True)
        cursor = batch["continue_key"]
        if not cursor:
            # No continuation key: the last batch has been processed.
            break
    print("Done")
    await local.noawait.end()
    # Give pending background tasks a moment to drain before the loop closes.
    await asyncio.sleep(1)


if __name__ == '__main__':
    local.loop.run_until_complete(main())

@ -10,7 +10,7 @@ pgvector==0.1.6
websockets==11.0 websockets==11.0
PyJWT==2.6.0 PyJWT==2.6.0
asyncpg-stubs==0.27.0 asyncpg-stubs==0.27.0
sqlalchemy==2.0.9 sqlalchemy==2.0.17
aiohttp-sse-client2==0.3.0 aiohttp-sse-client2==0.3.0
OpenCC==1.1.6 OpenCC==1.1.6
event-emitter-asyncio==1.0.4 event-emitter-asyncio==1.0.4

@ -196,12 +196,12 @@ class ChatCompleteService:
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk) self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk)
else: else:
# 创建新对话 # 创建新对话
title_info = self.embedding_search.title_info title_info = self.embedding_search.title_index
self.conversation_info = ConversationModel( self.conversation_info = ConversationModel(
user_id=self.user_id, user_id=self.user_id,
module="chatcomplete", module="chatcomplete",
page_id=title_info["page_id"], page_id=title_info.page_id,
rev_id=title_info["rev_id"], rev_id=title_info.latest_rev_id,
) )
self.conversation_info = await self.conversation_helper.add( self.conversation_info = await self.conversation_helper.add(
self.conversation_info, self.conversation_info,

@ -1,6 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, TypedDict from typing import Optional, TypedDict
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
import sqlalchemy
from api.model.embedding_search.title_collection import (
TitleCollectionHelper,
TitleCollectionModel,
)
from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
from api.model.embedding_search.page_index import PageIndexHelper from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService from service.database import DatabaseService
@ -9,14 +14,17 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences from utils.wiki import getWikiSentences
class EmbeddingRunningException(Exception): class EmbeddingRunningException(Exception):
pass pass
class EmbeddingSearchArgs(TypedDict): class EmbeddingSearchArgs(TypedDict):
limit: Optional[int] limit: Optional[int]
in_collection: Optional[bool] in_collection: Optional[bool]
distance_limit: Optional[float] distance_limit: Optional[float]
class EmbeddingSearchService: class EmbeddingSearchService:
indexing_page_ids: list[int] = [] indexing_page_ids: list[int] = []
@ -38,7 +46,7 @@ class EmbeddingSearchService:
self.page_id: Optional[int] = None self.page_id: Optional[int] = None
self.collection_id: Optional[int] = None self.collection_id: Optional[int] = None
self.title_info: Optional[TitleIndexModel] = None self.title_index: Optional[TitleIndexModel] = None
self.collection_info: Optional[TitleCollectionModel] = None self.collection_info: Optional[TitleCollectionModel] = None
self.page_info: dict = None self.page_info: dict = None
@ -50,13 +58,18 @@ class EmbeddingSearchService:
await self.title_index_helper.__aenter__() await self.title_index_helper.__aenter__()
await self.title_collection_helper.__aenter__() await self.title_collection_helper.__aenter__()
self.title_info = await self.title_index_helper.find_by_title(self.title) self.title_index = await self.title_index_helper.find_by_title(self.title)
if self.title_info is not None: if self.title_index is None:
self.page_id = self.title_info.page_id # Title may changed, get page info from page_id
self.collection_id = self.title_info.collection_id await self.load_page_info()
self.title_index = await self.title_index_helper.find_by_page_id(
self.page_index = PageIndexHelper( self.page_info["pageid"]
self.dbs, self.collection_id, self.page_id) )
self.page_id = self.page_info["pageid"]
else:
self.page_id = self.title_index.page_id
self.collection_id = self.title_index.collection_id
self.page_index = PageIndexHelper(self.dbs, self.collection_id)
await self.page_index.__aenter__() await self.page_index.__aenter__()
return self return self
@ -78,47 +91,106 @@ class EmbeddingSearchService:
if self.page_info is None or reload: if self.page_info is None or reload:
self.page_info = await self.mwapi.get_page_info(self.title) self.page_info = await self.mwapi.get_page_info(self.title)
async def should_update_page_index(self): async def should_update_page_index(self, remote_update=False):
if not remote_update:
if self.title_index is None:
return True
return self.title_index.indexed_rev_id != self.title_index.latest_rev_id
else:
await self.load_page_info() await self.load_page_info()
if (self.title_info is not None and await self.page_index_exists() and if (
self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]): self.title_index is not None
and await self.page_index_exists()
and self.title_index.indexed_rev_id == self.page_info["lastrevid"]
):
# Not changed # Not changed
return False return False
return True return True
async def prepare_update_index(self): async def update_title_index(self, remote_update=False):
# Check rev_id if not await self.should_update_page_index(remote_update):
await self.load_page_info()
if not await self.should_update_page_index():
return False return False
await self.load_page_info()
self.page_id = self.page_info["pageid"] self.page_id = self.page_info["pageid"]
if self.page_id in self.indexing_page_ids: if self.page_id in self.indexing_page_ids:
raise EmbeddingRunningException("Page index is running now") raise EmbeddingRunningException("Page index is running now")
self.title: str = self.page_info["title"]
self.base_title = self.title.split("/")[0]
# Find collection by base title
self.collection_info = await self.title_collection_helper.find_by_title(
self.base_title
)
if self.collection_info is None:
# Create collection # Create collection
self.collection_info = await self.title_collection_helper.find_by_title(self.base_title) self.collection_info = await self.title_collection_helper.add(
self.base_title
)
if self.collection_info is None: if self.collection_info is None:
self.collection_id = await self.title_collection_helper.add(self.base_title)
if self.collection_id is None:
raise Exception("Failed to create title collection") raise Exception("Failed to create title collection")
else:
self.collection_id = self.collection_info.id self.collection_id = self.collection_info.id
self.page_index = PageIndexHelper( if self.title_index == None:
self.dbs, self.collection_id, self.page_id) # Create title record
self.title_index = TitleIndexModel(
title=self.page_info["title"],
page_id=self.page_id,
indexed_rev_id=None,
latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=None,
)
self.title_index = await self.title_index_helper.add(self.title_index)
if self.title_index is None:
raise Exception("Failed to create title index")
else:
self.title_index.latest_rev_id = self.page_info["lastrevid"]
# Title changed, remove embedding
# Title sha1 will be updated by model helper
if self.title_index.title != self.page_info["title"]:
self.title_index.title = self.page_info["title"]
self.title_index.embedding = None
# Collection changed, remove old index
if self.collection_id != self.title_index.collection_id:
async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index:
await self.page_index.init_table()
old_page_index.remove_by_page_id(self.page_id)
self.title_index.collection_id = self.collection_id
await self.title_index_helper.update(self.title_index)
# Update collection main page id
if (
self.title == self.collection_info.title
and self.page_id != self.collection_info.page_id
):
await self.title_collection_helper.set_main_page_id(
self.base_title, self.page_id
)
if self.page_index:
await self.page_index.__aexit__(None, None, None)
self.page_index = PageIndexHelper(self.dbs, self.collection_id)
await self.page_index.__aenter__() await self.page_index.__aenter__()
await self.page_index.init_table() await self.page_index.init_table()
async def prepare_update_index(self):
await self.update_title_index()
page_content = await self.mwapi.parse_page(self.title) page_content = await self.mwapi.parse_page(self.title)
self.sentences = getWikiSentences(page_content) self.sentences = getWikiSentences(page_content)
self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False) self.unindexed_docs = await self.page_index.get_unindexed_doc(
self.sentences, self.page_id, with_temporary=False
)
return True return True
@ -153,8 +225,10 @@ class EmbeddingSearchService:
await on_progress(indexed_docs, len(self.unindexed_docs)) await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk): async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk, on_embedding_progress) (doc_chunk, token_usage) = await self.openai_api.get_embeddings(
await self.page_index.index_doc(doc_chunk) doc_chunk, on_embedding_progress
)
await self.page_index.index_doc(doc_chunk, self.page_id)
return token_usage return token_usage
@ -181,55 +255,34 @@ class EmbeddingSearchService:
if on_progress is not None: if on_progress is not None:
await on_progress(len(self.unindexed_docs), len(self.unindexed_docs)) await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))
await self.page_index.remove_outdated_doc(self.sentences) await self.page_index.remove_outdated_doc(self.sentences, self.page_id)
# Update database # Update database
if self.title_info is None: # This task may take a long time, refresh model to retrieve latest data
doc_chunk = [{"text": self.title}] self.title_index = await self.title_index_helper.refresh(self.title_index)
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"] self.title_index.indexed_rev_id = self.page_info["lastrevid"]
self.title_info = TitleIndexModel(
title=self.page_info["title"],
page_id=self.page_id,
rev_id=self.page_info["lastrevid"],
latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=embedding
)
res = await self.title_index_helper.add(self.title_info)
if res:
self.title_info = res
else:
if self.title != self.page_info["title"]:
self.title = self.page_info["title"]
# Update title embedding
if await self.title_index.awaitable_attrs.embedding is None:
doc_chunk = [{"text": self.title}] doc_chunk = [{"text": self.title}]
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk) (doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
total_token_usage += token_usage total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"] embedding = doc_chunk[0]["embedding"]
self.title_index.embedding = embedding
self.title_info.title = self.title await self.title_index_helper.update(self.title_index)
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
self.title_info.collection_id = self.collection_id
self.title_info.embedding = embedding
else:
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
await self.title_index_helper.update(self.title_info)
if (self.collection_info is None or
(self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)):
await self.title_collection_helper.set_page_id(self.base_title, self.page_id)
return total_token_usage return total_token_usage
async def search(self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6): async def search(
self,
query: str,
limit: int = 10,
in_collection: bool = False,
distance_limit: float = 0.6,
):
if self.page_index is None: if self.page_index is None:
raise Exception("Page index is not initialized") raise Exception("Page index is not initialized")
@ -240,7 +293,9 @@ class EmbeddingSearchService:
if query_embedding is None: if query_embedding is None:
return [], token_usage return [], token_usage
res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit) res = await self.page_index.search_text_embedding(
query_embedding, in_collection, limit, self.page_id
)
if res: if res:
filtered = [] filtered = []
for one in res: for one in res:

@ -36,6 +36,10 @@ class MediaWikiUserNoEnoughPointsException(Exception):
def __str__(self) -> str: def __str__(self) -> str:
return self.info return self.info
class GetAllPagesResponse(TypedDict):
title_list: list[str]
continue_key: Optional[str]
class ChatCompleteGetPointUsageResponse(TypedDict): class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int point_cost: int
@ -125,6 +129,38 @@ class MediaWikiApi:
return ret return ret
async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
"action": "query",
"format": "json",
"formatversion": "2",
"list": "allpages",
"apnamespace": 0,
"apfilterredir": "nonredirects",
"aplimit": 100,
}
if continue_key is not None:
params["apcontinue"] = continue_key
if start_title is not None:
params["apfrom"] = start_title
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
title_list = []
for page in data["query"]["allpages"]:
title_list.append(page["title"])
ret = GetAllPagesResponse(title_list=title_list, continue_key=None)
if "continue" in data and "apcontinue" in data["continue"]:
ret["continue_key"] = data["continue"]["apcontinue"]
return ret
async def is_logged_in(self): async def is_logged_in(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session: async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = { params = {

Loading…
Cancel
Save