增加在Collection中提问的功能

master
落雨楓 2 years ago
parent 7b4c70147b
commit 4aa627a23e

@ -1,6 +1,8 @@
import sys
import traceback
from aiohttp import web
from api.model.embedding_search.title_collection import TitleCollectionHelper
from api.model.embedding_search.title_index import TitleIndexHelper
from service.database import DatabaseService
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
@ -14,9 +16,15 @@ class EmbeddingSearch:
"title": {
"required": True,
},
"collection": {
"required": False,
"type": bool,
"default": False
}
})
page_title = params.get('title')
page_title = params.get("title")
is_collection = params.get("collection")
mwapi = MediaWikiApi.create()
db = await DatabaseService.create(request.app)
@ -27,47 +35,60 @@ class EmbeddingSearch:
try:
transatcion_id = None
async with EmbeddingSearchService(db, page_title) as embedding_search:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
title_list = [page_title]
if is_collection:
# Get collection titles
async with TitleCollectionHelper(db) as title_collection, TitleIndexHelper(db) as title_index_helper:
title_collection = await title_collection.find_by_title(page_title)
if title_collection is not None:
need_update_pages = await title_index_helper.get_need_update_index_list(title_collection.id)
title_list = []
for page_info in need_update_pages:
title_list.append(page_info.title)
await embedding_search.prepare_update_index()
page_count = len(title_list)
page_current = 0
index_updated = False
for one_title in title_list:
page_current += 1
async with EmbeddingSearchService(db, one_title) as embedding_search:
if await embedding_search.should_update_page_index():
if request.get("caller") == "user":
user_id = request.get("user")
usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
transatcion_id = usage_res.get("transaction_id")
async def on_progress(current, total):
await embedding_search.prepare_update_index()
await ws.send_json({
'event': 'progress',
'current': current,
'total': total
})
async def on_progress(current, total):
await ws.send_json({
"event": "progress",
"current": current,
"total": total,
"current_page": page_current,
"total_page": page_count,
})
token_usage = await embedding_search.update_page_index(on_progress)
await ws.send_json({
'event': 'done',
'status': 1,
'index_updated': True
})
token_usage = await embedding_search.update_page_index(on_progress)
index_updated = True
if transatcion_id:
await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
else:
await ws.send_json({
'event': 'done',
'status': 1,
'index_updated': False
})
if transatcion_id:
await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
await ws.send_json({
"event": "done",
"status": 1,
"index_updated": index_updated,
})
except MediaWikiPageNotFoundException:
error_msg = "Page \"%s\" not found." % page_title
await ws.send_json({
'event': 'error',
'status': -2,
'message': error_msg,
'error': {
'code': 'page_not_found',
'title': page_title,
"event": "error",
"status": -2,
"message": error_msg,
"error": {
"code": "page_not_found",
"title": page_title,
},
})
if transatcion_id:
@ -77,12 +98,12 @@ class EmbeddingSearch:
print(error_msg, file=sys.stderr)
traceback.print_exc()
await ws.send_json({
'event': 'error',
'status': -3,
'message': error_msg,
'error': {
'code': e.code,
'info': e.info,
"event": "error",
"status": -3,
"message": error_msg,
"error": {
"code": e.code,
"info": e.info,
},
})
if transatcion_id:
@ -90,11 +111,11 @@ class EmbeddingSearch:
except EmbeddingRunningException:
error_msg = "Page index is running now"
await ws.send_json({
'event': 'error',
'status': -4,
'message': error_msg,
'error': {
'code': 'page_index_running',
"event": "error",
"status": -4,
"message": error_msg,
"error": {
"code": "page_index_running",
},
})
if transatcion_id:
@ -104,11 +125,11 @@ class EmbeddingSearch:
print(error_msg, file=sys.stderr)
traceback.print_exc()
await ws.send_json({
'event': 'error',
'status': -1,
'message': error_msg,
'error': {
'code': 'internal_server_error',
"event": "error",
"status": -1,
"message": error_msg,
"error": {
"code": "internal_server_error",
}
})
if transatcion_id:
@ -133,9 +154,15 @@ class EmbeddingSearch:
if transatcion_id:
result = await mwapi.ai_toolbox_end_transaction(transatcion_id, token_usage)
return await utils.web.api_response(1, {"data_indexed": True})
return await utils.web.api_response(1, {
"data_indexed": True,
"title": embedding_search.title
})
else:
return await utils.web.api_response(1, {"data_indexed": False})
return await utils.web.api_response(1, {
"data_indexed": False,
"title": embedding_search.title
})
except MediaWikiPageNotFoundException:
error_msg = "Page \"%s\" not found." % page_title
if transatcion_id:
@ -253,4 +280,43 @@ class EmbeddingSearch:
"code": "internal-server-error",
"message": error_msg
}, request=request, http_status=500)
return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)
return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)
@staticmethod
@utils.web.token_auth
async def sys_update_title_info(request: web.Request):
params = await utils.web.get_param(request, {
"title": {
"required": True,
},
})
if request.get("caller") == "user":
return await utils.web.api_response(-1, error={
"code": "permission-denied",
"message": "This api is only for system caller."
}, request=request, http_status=403)
page_title = params.get("title")
db = await DatabaseService.create(request.app)
async with EmbeddingSearchService(db, page_title) as embedding_search:
try:
await embedding_search.update_title_index(True)
except MediaWikiPageNotFoundException:
error_msg = "Page \"%s\" not found." % page_title
return await utils.web.api_response(-2, error={
"code": "page-not-found",
"title": page_title,
"message": error_msg
}, request=request, http_status=404)
except Exception as err:
err_msg = str(err)
return await utils.web.api_response(-1, error={
"code": "internal-server-error",
"message": err_msg
}, request=request, http_status=500)
return await utils.web.api_response(1, request=request)

@ -129,7 +129,8 @@ class ChatCompleteTask:
async def end(self):
await self.chat_complete_service.__aexit__(None, None, None)
del chat_complete_tasks[self.task_id]
if self.task_id in chat_complete_tasks:
del chat_complete_tasks[self.task_id]
self.is_finished = True
self.finished_time = time.time()

@ -2,10 +2,11 @@ from __future__ import annotations
from typing import TypeVar
import sqlalchemy
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from service.database import DatabaseService
class BaseModel(DeclarativeBase):
class BaseModel(AsyncAttrs, DeclarativeBase):
pass
class BaseHelper:

@ -21,7 +21,9 @@ page_index_model_list: dict[int, Type[AbstractPageIndexModel]] = {}
class AbstractPageIndexModel(BaseModel):
__abstract__ = True
id: Mapped[int] = mapped_column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
id: Mapped[int] = mapped_column(
sqlalchemy.Integer, primary_key=True, autoincrement=True
)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
sha1: Mapped[str] = mapped_column(sqlalchemy.String(40), index=True)
embedding: Mapped[np.ndarray] = mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE))
@ -36,34 +38,37 @@ def create_page_index_model(collection_id: int):
if collection_id in page_index_model_list:
return page_index_model_list[collection_id]
else:
class _PageIndexModel(AbstractPageIndexModel):
__tablename__ = "embedding_search_page_index_%s" % str(collection_id)
embedding_index = sqlalchemy.Index(__tablename__ + "_embedding_idx", AbstractPageIndexModel.embedding,
postgresql_using='ivfflat',
postgresql_ops={'embedding': 'vector_cosine_ops'})
embedding_index = sqlalchemy.Index(
__tablename__ + "_embedding_idx",
AbstractPageIndexModel.embedding,
postgresql_using="ivfflat",
postgresql_ops={"embedding": "vector_cosine_ops"},
)
page_index_model_list[collection_id] = _PageIndexModel
return _PageIndexModel
class PageIndexHelper:
columns = [
"id",
"page_id"
"sha1",
"page_id" "sha1",
"text",
"text_len",
"markdown",
"markdown_len",
"embedding",
"temp_doc_session_id"
"temp_doc_session_id",
]
def __init__(self, dbs: DatabaseService, collection_id: int, page_id: Optional[int]):
def __init__(self, dbs: DatabaseService, collection_id: int):
self.dbs = dbs
self.collection_id = collection_id
self.page_id = page_id if page_id is not None else -1
self.table_name = "embedding_search_page_index_%s" % str(collection_id)
self.initialized = False
self.table_initialized = False
@ -71,10 +76,11 @@ class PageIndexHelper:
"""
Initialize table
"""
async def __aenter__(self):
if self.initialized:
return
self.dbpool = self.dbs.pool.acquire()
self.dbi = await self.dbpool.__aenter__()
@ -95,12 +101,16 @@ class PageIndexHelper:
await self.session.__aexit__(exc_type, exc, tb)
async def table_exists(self):
exists = await self.dbi.fetchval("""SELECT EXISTS (
exists = await self.dbi.fetchval(
"""SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = $1
);""", self.table_name, column=0)
);""",
self.table_name,
column=0,
)
return bool(exists)
@ -113,22 +123,6 @@ class PageIndexHelper:
async with self.dbs.engine.begin() as conn:
await conn.run_sync(self.orm.__table__.create)
# await self.dbi.execute(("""CREATE TABLE IF NOT EXISTS /*_*/ (
# id SERIAL PRIMARY KEY,
# page_id INTEGER NOT NULL,
# sha1 VARCHAR(40) NOT NULL,
# text TEXT NOT NULL,
# text_len INTEGER NOT NULL,
# embedding VECTOR(%d) NOT NULL,
# markdown TEXT NULL,
# markdown_len INTEGER NULL,
# temp_doc_session_id INTEGER NULL
# );
# CREATE INDEX /*_*/_page_id_idx ON /*_*/ (page_id);
# CREATE INDEX /*_*/_sha1_idx ON /*_*/ (sha1);
# CREATE INDEX /*_*/_temp_doc_session_id_idx ON /*_*/ (temp_doc_session_id);
# """ % config.EMBEDDING_VECTOR_SIZE).replace("/*_*/", self.table_name))
self.table_initialized = True
async def create_embedding_index(self):
@ -142,7 +136,9 @@ class PageIndexHelper:
sha1 = hashlib.sha1(item["text"].encode("utf-8")).hexdigest()
item["sha1"] = sha1
async def get_indexed_sha1(self, with_temporary: bool = True, in_collection: bool = False):
async def get_indexed_sha1(
self, page_id: int, with_temporary: bool = True, in_collection: bool = False
):
indexed_sha1_list = []
stmt = select(self.orm).column(self.orm.sha1)
@ -151,7 +147,7 @@ class PageIndexHelper:
stmt = stmt.where(self.orm.temp_doc_session_id == None)
if not in_collection:
stmt = stmt.where(self.orm.page_id == self.page_id)
stmt = stmt.where(self.orm.page_id == page_id)
ret: list[AbstractPageIndexModel] = await self.session.scalars(stmt)
@ -160,8 +156,10 @@ class PageIndexHelper:
return indexed_sha1_list
async def get_unindexed_doc(self, doc: list, with_temporary: bool = True):
indexed_sha1_list = await self.get_indexed_sha1(with_temporary)
async def get_unindexed_doc(
self, doc: list, page_id: int, with_temporary: bool = True
):
indexed_sha1_list = await self.get_indexed_sha1(page_id, with_temporary)
self.sha1_doc(doc)
should_index = []
@ -171,10 +169,10 @@ class PageIndexHelper:
return should_index
async def remove_outdated_doc(self, doc: list):
await self.clear_temp()
async def remove_outdated_doc(self, doc: list, page_id: int):
await self.clear_temp(page_id=page_id)
indexed_sha1_list = await self.get_indexed_sha1(False)
indexed_sha1_list = await self.get_indexed_sha1(page_id, False)
self.sha1_doc(doc)
doc_sha1_list = [item["sha1"] for item in doc]
@ -185,17 +183,24 @@ class PageIndexHelper:
should_remove.append(sha1)
if len(should_remove) > 0:
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)" % (self.table_name),
self.page_id, should_remove)
async def index_doc(self, doc: list):
await self.dbi.execute(
"DELETE FROM %s WHERE page_id = $1 AND sha1 = ANY($2)"
% (self.table_name),
page_id,
should_remove,
)
async def index_doc(self, doc: list, page_id: int):
need_create_index = False
indexed_persist_sha1_list = []
indexed_temp_sha1_list = []
ret = await self.dbi.fetch("SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1" % (self.table_name),
self.page_id)
ret = await self.dbi.fetch(
"SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1"
% (self.table_name),
page_id,
)
for row in ret:
if row[1]:
indexed_temp_sha1_list.append(row[0])
@ -226,28 +231,48 @@ class PageIndexHelper:
should_remove.append(sha1)
if len(should_index) > 0:
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);""" % (self.table_name),
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"]) for item in should_index])
await self.dbi.executemany(
"""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, NULL);"""
% (self.table_name),
[
(
item["sha1"],
page_id,
item["text"],
len(item["text"]),
item["markdown"],
len(item["markdown"]),
item["embedding"],
)
for item in should_index
],
)
if len(should_persist) > 0:
await self.dbi.executemany("UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2" % (self.table_name),
[(self.page_id, sha1) for sha1 in should_persist])
await self.dbi.executemany(
"UPDATE %s SET temp_doc_session_id = NULL WHERE page_id = $1 AND sha1 = $2"
% (self.table_name),
[(page_id, sha1) for sha1 in should_persist],
)
if need_create_index:
await self.create_embedding_index()
"""
Add temporary document to the index
"""
async def index_temp_doc(self, doc: list, temp_doc_session_id: int):
indexed_sha1_list = []
indexed_temp_sha1_list = []
doc_sha1_list = []
sql = "SELECT sha1, temp_doc_session_id FROM %s WHERE page_id = $1 AND (temp_doc_session_id = $2 OR temp_doc_session_id IS NULL)" % (
self.table_name)
ret = await self.dbi.fetch(sql, self.page_id, temp_doc_session_id)
sql = (
"SELECT sha1, temp_doc_session_id FROM %s WHERE (temp_doc_session_id = $1 OR temp_doc_session_id IS NULL)"
% (self.table_name)
)
ret = await self.dbi.fetch(sql, temp_doc_session_id)
for row in ret:
indexed_sha1_list.append(row[0])
if row[1]:
@ -269,41 +294,79 @@ class PageIndexHelper:
should_remove.append(sha1)
if len(should_index) > 0:
await self.dbi.executemany("""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);""" % (self.table_name),
[(item["sha1"], self.page_id, item["text"], len(item["text"]), item["markdown"], len(item["markdown"]), item["embedding"], temp_doc_session_id) for item in should_index])
await self.dbi.executemany(
"""INSERT INTO %s (sha1, page_id, text, text_len, markdown, markdown_len, embedding, temp_doc_session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);"""
% (self.table_name),
[
(
item["sha1"],
item["page_id"],
item["text"],
len(item["text"]),
item["markdown"],
len(item["markdown"]),
item["embedding"],
temp_doc_session_id,
)
for item in should_index
],
)
if len(should_remove) > 0:
await self.dbi.execute("DELETE FROM %s WHERE page_id = $1 AND temp_doc_session_id = $2 AND sha1 = ANY($3)" % (self.table_name),
self.page_id, temp_doc_session_id, should_remove)
await self.dbi.execute(
"DELETE FROM %s WHERE temp_doc_session_id = $1 AND sha1 = ANY($2)"
% (self.table_name),
temp_doc_session_id,
should_remove,
)
"""
Search for text by consine similary
"""
async def search_text_embedding(self, embedding: np.ndarray, in_collection: bool = False, limit: int = 10):
async def search_text_embedding(
self,
embedding: np.ndarray,
in_collection: bool = False,
limit: int = 10,
page_id: Optional[int] = None,
):
if in_collection:
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
return await self.dbi.fetch(
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
FROM %s
ORDER BY distance ASC
LIMIT %d""" % (self.table_name, limit), embedding)
LIMIT %d"""
% (self.table_name, limit),
embedding,
)
else:
return await self.dbi.fetch("""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
return await self.dbi.fetch(
"""SELECT id, sha1, text, text_len, markdown, markdown_len, embedding <-> $1 AS distance
FROM %s
WHERE page_id = $2
ORDER BY distance ASC
LIMIT %d""" % (self.table_name, limit), embedding, self.page_id)
LIMIT %d"""
% (self.table_name, limit),
embedding,
page_id,
)
"""
Clear temporary index
"""
async def clear_temp(self, in_collection: bool = False, temp_doc_session_id: int = None):
async def clear_temp(
self, in_collection: bool = False, temp_doc_session_id: int = None, page_id: Optional[int] = None
):
sql = "DELETE FROM %s" % (self.table_name)
where = []
params = []
if not in_collection:
params.append(self.page_id)
if not in_collection and page_id:
params.append(page_id)
where.append("page_id = $%d" % len(params))
if temp_doc_session_id:
@ -316,3 +379,8 @@ class PageIndexHelper:
sql += " WHERE " + (" AND ".join(where))
await self.dbi.execute(sql, *params)
async def remove_by_page_id(self, page_id: int):
stmt = delete(self.orm).where(self.orm.page_id == page_id)
await self.session.execute(stmt)
await self.session.commit()

@ -23,11 +23,11 @@ class TitleCollectionHelper(BaseHelper):
self.session.add(obj)
await self.session.commit()
await self.session.refresh(obj)
return obj.id
return obj
return False
return None
async def set_page_id(self, title: str, page_id: Optional[str] = None):
async def set_main_page_id(self, title: str, page_id: Optional[str] = None):
stmt = update(TitleCollectionModel).where(TitleCollectionModel.title == title).values(page_id=page_id)
await self.session.execute(stmt)
await self.session.commit()
@ -37,6 +37,10 @@ class TitleCollectionHelper(BaseHelper):
await self.session.execute(stmt)
await self.session.commit()
async def find_by_id(self, id: int):
stmt = select(TitleCollectionModel).where(TitleCollectionModel.id == id)
return await self.session.scalar(stmt)
async def find_by_title(self, title: str):
stmt = select(TitleCollectionModel).where(TitleCollectionModel.title == title)
return await self.session.scalar(stmt)

@ -5,7 +5,7 @@ import numpy as np
from pgvector.sqlalchemy import Vector
from pgvector.asyncpg import register_vector
import sqlalchemy
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred
from sqlalchemy.orm import mapped_column, relationship, Mapped, deferred, defer
from sqlalchemy.ext.asyncio import AsyncEngine
import config
@ -20,9 +20,9 @@ class TitleIndexModel(BaseModel):
title: Mapped[str] = mapped_column(sqlalchemy.String(255), index=True)
page_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
collection_id: Mapped[int] = mapped_column(sqlalchemy.Integer, index=True)
rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
indexed_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=True)
latest_rev_id: Mapped[int] = mapped_column(sqlalchemy.Integer)
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE)))
embedding: Mapped[np.ndarray] = deferred(mapped_column(Vector(config.EMBEDDING_VECTOR_SIZE), nullable=True))
embedding_index = sqlalchemy.Index("embedding_search_title_index_embedding_idx", embedding,
postgresql_using='ivfflat',
@ -37,7 +37,7 @@ class TitleIndexHelper(BaseHelper):
"title",
"page_id",
"collection_id",
"rev_id",
"indexed_rev_id",
"latest_rev_id"
"embedding",
]
@ -52,11 +52,11 @@ class TitleIndexHelper(BaseHelper):
await register_vector(self.dbi)
await super().__aenter__()
return await super().__aenter__()
async def __aexit__(self, exc_type, exc, tb):
await self.dbpool.__aexit__(exc_type, exc, tb)
await super().__aexit__(exc_type, exc, tb)
return await super().__aexit__(exc_type, exc, tb)
def get_columns(self, exclude=[]):
if len(exclude) == 0:
@ -92,6 +92,10 @@ class TitleIndexHelper(BaseHelper):
await self.session.commit()
await self.session.refresh(obj)
return obj
async def refresh(self, obj: TitleIndexModel):
await self.session.refresh(obj)
return obj
"""
Search for titles by consine similary
@ -120,6 +124,15 @@ class TitleIndexHelper(BaseHelper):
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.page_id == page_id)
return await self.session.scalar(stmt)
async def find_list_by_collection_id(self, collection_id: int):
async def find_list_by_collection_id(self, collection_id: int) -> sqlalchemy.ScalarResult[TitleIndexModel]:
stmt = sqlalchemy.select(TitleIndexModel).where(TitleIndexModel.collection_id == collection_id)
return await self.session.scalars(stmt)
return await self.session.scalars(stmt)
async def get_need_update_index_list(self, collection_id: int) -> list[TitleIndexModel]:
page_list = await self.find_list_by_collection_id(collection_id)
result: list[TitleIndexModel] = []
for page_info in page_list:
if page_info.indexed_rev_id != page_info.latest_rev_id:
result.append(page_info)
return result

@ -28,6 +28,7 @@ def init(app: web.Application):
web.route('*', '/embedding_search/index_page', EmbeddingSearch.index_page),
web.route('*', '/embedding_search/search', EmbeddingSearch.search),
web.route('POST', '/sys/embedding_search/title/update', EmbeddingSearch.sys_update_title_info),
web.route('*', '/chatcomplete/conversation_chunk/list', ChatComplete.get_conversation_chunk_list),
web.route('*', '/chatcomplete/conversation_chunk/info', ChatComplete.get_conversation_chunk),

@ -0,0 +1,4 @@
import sys
import pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))

@ -0,0 +1,35 @@
import asyncio
import base as _
import local
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi
async def main():
dbs = await DatabaseService.create()
mw_api = MediaWikiApi.create()
continue_key = None
while True:
page_res = await mw_api.get_all_pages(continue_key)
title_list = page_res["title_list"]
for page_title in title_list:
print("Indexing %s" % page_title)
async with EmbeddingSearchService(dbs, page_title) as embedding_search:
await embedding_search.update_title_index(True)
if not page_res["continue_key"]:
break
continue_key = page_res["continue_key"]
print("Done")
await local.noawait.end()
await asyncio.sleep(1)
if __name__ == '__main__':
local.loop.run_until_complete(main())

@ -10,7 +10,7 @@ pgvector==0.1.6
websockets==11.0
PyJWT==2.6.0
asyncpg-stubs==0.27.0
sqlalchemy==2.0.9
sqlalchemy==2.0.17
aiohttp-sse-client2==0.3.0
OpenCC==1.1.6
event-emitter-asyncio==1.0.4

@ -196,12 +196,12 @@ class ChatCompleteService:
self.conversation_chunk = await self.conversation_chunk_helper.add(self.conversation_chunk)
else:
# 创建新对话
title_info = self.embedding_search.title_info
title_info = self.embedding_search.title_index
self.conversation_info = ConversationModel(
user_id=self.user_id,
module="chatcomplete",
page_id=title_info["page_id"],
rev_id=title_info["rev_id"],
page_id=title_info.page_id,
rev_id=title_info.latest_rev_id,
)
self.conversation_info = await self.conversation_helper.add(
self.conversation_info,

@ -1,6 +1,11 @@
from __future__ import annotations
from typing import Optional, TypedDict
from api.model.embedding_search.title_collection import TitleCollectionHelper, TitleCollectionModel
import sqlalchemy
from api.model.embedding_search.title_collection import (
TitleCollectionHelper,
TitleCollectionModel,
)
from api.model.embedding_search.title_index import TitleIndexHelper, TitleIndexModel
from api.model.embedding_search.page_index import PageIndexHelper
from service.database import DatabaseService
@ -9,14 +14,17 @@ from service.openai_api import OpenAIApi
from service.tiktoken import TikTokenService
from utils.wiki import getWikiSentences
class EmbeddingRunningException(Exception):
pass
class EmbeddingSearchArgs(TypedDict):
limit: Optional[int]
in_collection: Optional[bool]
distance_limit: Optional[float]
class EmbeddingSearchService:
indexing_page_ids: list[int] = []
@ -38,7 +46,7 @@ class EmbeddingSearchService:
self.page_id: Optional[int] = None
self.collection_id: Optional[int] = None
self.title_info: Optional[TitleIndexModel] = None
self.title_index: Optional[TitleIndexModel] = None
self.collection_info: Optional[TitleCollectionModel] = None
self.page_info: dict = None
@ -50,17 +58,22 @@ class EmbeddingSearchService:
await self.title_index_helper.__aenter__()
await self.title_collection_helper.__aenter__()
self.title_info = await self.title_index_helper.find_by_title(self.title)
if self.title_info is not None:
self.page_id = self.title_info.page_id
self.collection_id = self.title_info.collection_id
self.page_index = PageIndexHelper(
self.dbs, self.collection_id, self.page_id)
self.title_index = await self.title_index_helper.find_by_title(self.title)
if self.title_index is None:
# Title may changed, get page info from page_id
await self.load_page_info()
self.title_index = await self.title_index_helper.find_by_page_id(
self.page_info["pageid"]
)
self.page_id = self.page_info["pageid"]
else:
self.page_id = self.title_index.page_id
self.collection_id = self.title_index.collection_id
self.page_index = PageIndexHelper(self.dbs, self.collection_id)
await self.page_index.__aenter__()
return self
async def __aexit__(self, exc_type, exc, tb):
await self.title_index_helper.__aexit__(exc_type, exc, tb)
await self.title_collection_helper.__aexit__(exc_type, exc, tb)
@ -68,7 +81,7 @@ class EmbeddingSearchService:
if self.page_index is not None:
await self.page_index.__aexit__(exc_type, exc, tb)
async def page_index_exists(self, check_table = True):
async def page_index_exists(self, check_table=True):
if check_table:
return self.page_index and await self.page_index.table_exists()
else:
@ -78,47 +91,106 @@ class EmbeddingSearchService:
if self.page_info is None or reload:
self.page_info = await self.mwapi.get_page_info(self.title)
async def should_update_page_index(self):
await self.load_page_info()
async def should_update_page_index(self, remote_update=False):
if not remote_update:
if self.title_index is None:
return True
return self.title_index.indexed_rev_id != self.title_index.latest_rev_id
else:
await self.load_page_info()
if (self.title_info is not None and await self.page_index_exists() and
self.title_info.title == self.page_info["title"] and self.title_info.rev_id == self.page_info["lastrevid"]):
# Not changed
return False
return True
if (
self.title_index is not None
and await self.page_index_exists()
and self.title_index.indexed_rev_id == self.page_info["lastrevid"]
):
# Not changed
return False
async def prepare_update_index(self):
# Check rev_id
await self.load_page_info()
return True
if not await self.should_update_page_index():
async def update_title_index(self, remote_update=False):
if not await self.should_update_page_index(remote_update):
return False
await self.load_page_info()
self.page_id = self.page_info["pageid"]
if self.page_id in self.indexing_page_ids:
raise EmbeddingRunningException("Page index is running now")
# Create collection
self.collection_info = await self.title_collection_helper.find_by_title(self.base_title)
self.title: str = self.page_info["title"]
self.base_title = self.title.split("/")[0]
# Find collection by base title
self.collection_info = await self.title_collection_helper.find_by_title(
self.base_title
)
if self.collection_info is None:
self.collection_id = await self.title_collection_helper.add(self.base_title)
if self.collection_id is None:
# Create collection
self.collection_info = await self.title_collection_helper.add(
self.base_title
)
if self.collection_info is None:
raise Exception("Failed to create title collection")
self.collection_id = self.collection_info.id
if self.title_index == None:
# Create title record
self.title_index = TitleIndexModel(
title=self.page_info["title"],
page_id=self.page_id,
indexed_rev_id=None,
latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=None,
)
self.title_index = await self.title_index_helper.add(self.title_index)
if self.title_index is None:
raise Exception("Failed to create title index")
else:
self.collection_id = self.collection_info.id
self.title_index.latest_rev_id = self.page_info["lastrevid"]
# Title changed, remove embedding
# Title sha1 will be updated by model helper
if self.title_index.title != self.page_info["title"]:
self.title_index.title = self.page_info["title"]
self.title_index.embedding = None
# Collection changed, remove old index
if self.collection_id != self.title_index.collection_id:
async with PageIndexHelper(self.dbs, self.title_index.collection_id) as old_page_index:
await self.page_index.init_table()
old_page_index.remove_by_page_id(self.page_id)
self.title_index.collection_id = self.collection_id
await self.title_index_helper.update(self.title_index)
# Update collection main page id
if (
self.title == self.collection_info.title
and self.page_id != self.collection_info.page_id
):
await self.title_collection_helper.set_main_page_id(
self.base_title, self.page_id
)
self.page_index = PageIndexHelper(
self.dbs, self.collection_id, self.page_id)
if self.page_index:
await self.page_index.__aexit__(None, None, None)
self.page_index = PageIndexHelper(self.dbs, self.collection_id)
await self.page_index.__aenter__()
await self.page_index.init_table()
async def prepare_update_index(self):
await self.update_title_index()
page_content = await self.mwapi.parse_page(self.title)
self.sentences = getWikiSentences(page_content)
self.unindexed_docs = await self.page_index.get_unindexed_doc(self.sentences, with_temporary=False)
self.unindexed_docs = await self.page_index.get_unindexed_doc(
self.sentences, self.page_id, with_temporary=False
)
return True
@ -153,15 +225,17 @@ class EmbeddingSearchService:
await on_progress(indexed_docs, len(self.unindexed_docs))
async def embedding_doc(doc_chunk):
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk, on_embedding_progress)
await self.page_index.index_doc(doc_chunk)
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(
doc_chunk, on_embedding_progress
)
await self.page_index.index_doc(doc_chunk, self.page_id)
return token_usage
if len(self.unindexed_docs) > 0:
if on_progress is not None:
await on_progress(0, len(self.unindexed_docs))
for doc in self.unindexed_docs:
chunk_len += len(doc)
@ -181,55 +255,34 @@ class EmbeddingSearchService:
if on_progress is not None:
await on_progress(len(self.unindexed_docs), len(self.unindexed_docs))
await self.page_index.remove_outdated_doc(self.sentences)
await self.page_index.remove_outdated_doc(self.sentences, self.page_id)
# Update database
if self.title_info is None:
# This task may take a long time, refresh model to retrieve latest data
self.title_index = await self.title_index_helper.refresh(self.title_index)
self.title_index.indexed_rev_id = self.page_info["lastrevid"]
# Update title embedding
if await self.title_index.awaitable_attrs.embedding is None:
doc_chunk = [{"text": self.title}]
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"]
self.title_index.embedding = embedding
self.title_info = TitleIndexModel(
title=self.page_info["title"],
page_id=self.page_id,
rev_id=self.page_info["lastrevid"],
latest_rev_id=self.page_info["lastrevid"],
collection_id=self.collection_id,
embedding=embedding
)
res = await self.title_index_helper.add(self.title_info)
if res:
self.title_info = res
else:
if self.title != self.page_info["title"]:
self.title = self.page_info["title"]
doc_chunk = [{"text": self.title}]
(doc_chunk, token_usage) = await self.openai_api.get_embeddings(doc_chunk)
total_token_usage += token_usage
embedding = doc_chunk[0]["embedding"]
self.title_info.title = self.title
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
self.title_info.collection_id = self.collection_id
self.title_info.embedding = embedding
else:
self.title_info.rev_id = self.page_info["lastrevid"]
self.title_info.latest_rev_id = self.page_info["lastrevid"]
await self.title_index_helper.update(self.title_info)
if (self.collection_info is None or
(self.base_title == self.collection_info.title and self.page_id != self.collection_info.page_id)):
await self.title_collection_helper.set_page_id(self.base_title, self.page_id)
await self.title_index_helper.update(self.title_index)
return total_token_usage
async def search(self, query: str, limit: int = 10, in_collection: bool = False, distance_limit: float = 0.6):
async def search(
self,
query: str,
limit: int = 10,
in_collection: bool = False,
distance_limit: float = 0.6,
):
if self.page_index is None:
raise Exception("Page index is not initialized")
@ -240,7 +293,9 @@ class EmbeddingSearchService:
if query_embedding is None:
return [], token_usage
res = await self.page_index.search_text_embedding(query_embedding, in_collection, limit)
res = await self.page_index.search_text_embedding(
query_embedding, in_collection, limit, self.page_id
)
if res:
filtered = []
for one in res:

@ -35,6 +35,10 @@ class MediaWikiUserNoEnoughPointsException(Exception):
def __str__(self) -> str:
return self.info
class GetAllPagesResponse(TypedDict):
title_list: list[str]
continue_key: Optional[str]
class ChatCompleteGetPointUsageResponse(TypedDict):
point_cost: int
@ -125,6 +129,38 @@ class MediaWikiApi:
return ret
async def get_all_pages(self, continue_key: Optional[str] = None, start_title: Optional[str] = None) -> GetAllPagesResponse:
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {
"action": "query",
"format": "json",
"formatversion": "2",
"list": "allpages",
"apnamespace": 0,
"apfilterredir": "nonredirects",
"aplimit": 100,
}
if continue_key is not None:
params["apcontinue"] = continue_key
if start_title is not None:
params["apfrom"] = start_title
async with session.get(self.api_url, params=params, proxy=config.REQUEST_PROXY) as resp:
data = await resp.json()
if "error" in data:
raise MediaWikiApiException(data["error"]["info"], data["error"]["code"])
title_list = []
for page in data["query"]["allpages"]:
title_list.append(page["title"])
ret = GetAllPagesResponse(title_list=title_list, continue_key=None)
if "continue" in data and "apcontinue" in data["continue"]:
ret["continue_key"] = data["continue"]["apcontinue"]
return ret
async def is_logged_in(self):
async with aiohttp.ClientSession(cookie_jar=self.cookie_jar) as session:
params = {

Loading…
Cancel
Save