You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
325 lines
13 KiB
Python
325 lines
13 KiB
Python
import sys
|
|
import traceback
|
|
from aiohttp import web
|
|
from libs.config import Config
|
|
from server.model.embedding_search.title_collection import TitleCollectionHelper
|
|
from server.model.embedding_search.title_index import TitleIndexHelper
|
|
from service.database import DatabaseService
|
|
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
|
import utils.web
|
|
|
|
class EmbeddingSearch:
    """aiohttp request handlers for the embedding-search endpoints.

    Every public handler is a static coroutine wrapped by
    ``utils.web.token_auth``.  Responses carry a numeric ``status``:

    *  ``1`` - success
    * ``-1`` - unexpected internal error
    * ``-2`` - the requested wiki page does not exist
    * ``-3`` - the MediaWiki API reported an error
    * ``-4`` - an index job for the page is already running
    """

    @staticmethod
    @utils.web.token_auth
    async def index_page(request: web.Request):
        """Build or refresh the embedding index of a page.

        Request parameters:
            title: Required. Title of the page (or collection) to index.
            collection: Optional bool, default ``False``. Only honoured on
                the WebSocket transport: when true, every page of the
                collection named ``title`` that needs re-indexing is
                processed.

        When the request is a WebSocket upgrade, progress is streamed as
        JSON events and the ``WebSocketResponse`` is returned; otherwise a
        single JSON API response is returned.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "collection": {
                "required": False,
                "type": bool,
                "default": False,
            },
        })

        page_title = params.get("title")
        is_collection = params.get("collection")

        mwapi = MediaWikiApi.create()
        db = await DatabaseService.create(request.app)

        if utils.web.is_websocket(request):
            return await EmbeddingSearch._index_page_ws(
                request, mwapi, db, page_title, is_collection)
        return await EmbeddingSearch._index_page_http(mwapi, db, page_title)

    @staticmethod
    async def _index_page_ws(request, mwapi, db, page_title, is_collection):
        """Index one page or a collection, streaming events over a WebSocket.

        Emits ``progress`` events while indexing, then a final ``done``
        event, or a single ``error`` event on failure.  The socket is always
        closed before returning, and the prepared ``WebSocketResponse`` is
        returned so that the aiohttp handler yields a response object.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # NOTE: usage accounting is currently disabled (see the commented-out
        # block below), so transaction_id stays None and the end/cancel calls
        # are no-ops until it is re-enabled.
        transaction_id = None
        try:
            title_list = [page_title]
            if is_collection:
                # Replace the single title with the collection's stale pages.
                async with TitleCollectionHelper(db) as collection_helper, \
                        TitleIndexHelper(db) as title_index_helper:
                    collection = await collection_helper.find_by_title(page_title)
                    if collection is not None:
                        need_update_pages = await title_index_helper.get_need_update_index_list(collection.id)
                        title_list = [page_info.title for page_info in need_update_pages]

            page_count = len(title_list)
            index_updated = False
            for page_current, one_title in enumerate(title_list, start=1):
                async with EmbeddingSearchService(db, one_title) as embedding_search:
                    if await embedding_search.should_update_page_index():
                        # if request.get("caller") == "user":
                        #     user_id = request.get("user")
                        #     usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                        #     transaction_id = usage_res.get("transaction_id")

                        await embedding_search.prepare_update_index()

                        async def on_progress(current, total):
                            # Invoked by update_page_index during this
                            # iteration, so page_current is the current page.
                            await ws.send_json({
                                "event": "progress",
                                "current": current,
                                "total": total,
                                "current_page": page_current,
                                "total_page": page_count,
                            })

                        token_usage = await embedding_search.update_page_index(on_progress)
                        index_updated = True

                        if transaction_id:
                            await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)

            await ws.send_json({
                "event": "done",
                "status": 1,
                "index_updated": index_updated,
            })
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            await ws.send_json({
                "event": "error",
                "status": -2,
                "message": error_msg,
                "error": {
                    "code": "page-not-found",
                    "title": page_title,
                },
            })
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await ws.send_json({
                "event": "error",
                "status": -3,
                "message": error_msg,
                "error": {
                    "code": e.code,
                    "info": e.info,
                },
            })
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
        except EmbeddingRunningException:
            error_msg = "Page index is running now"
            await ws.send_json({
                "event": "error",
                "status": -4,
                "message": error_msg,
                "error": {
                    "code": "page-index-running",
                },
            })
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
        except ConnectionResetError:
            pass  # Ignore websocket close error
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await ws.send_json({
                "event": "error",
                "status": -1,
                "message": error_msg,
                "error": {
                    "code": "internal-server-error",
                },
            })
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)
        finally:
            await ws.close()

        # An aiohttp handler must return its response object, including for
        # WebSocket connections that have already been closed.
        return ws

    @staticmethod
    async def _index_page_http(mwapi, db, page_title):
        """Index a single page and answer with one JSON API response.

        The success payload carries ``data_indexed`` (whether the index was
        rebuilt) and ``title`` (read from the search service).  Failures map
        to HTTP 404 (page not found), 429 (index already running) or 500.
        """
        # NOTE: usage accounting is currently disabled (see the commented-out
        # block below), so transaction_id stays None.
        transaction_id = None
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                if not await embedding_search.should_update_page_index():
                    return await utils.web.api_response(1, {
                        "data_indexed": False,
                        "title": embedding_search.title
                    })

                # if request.get("caller") == "user":
                #     user_id = request.get("user")
                #     usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                #     transaction_id = usage_res.get("transaction_id")

                await embedding_search.prepare_update_index()

                token_usage = await embedding_search.update_page_index()

                if transaction_id:
                    await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)

                return await utils.web.api_response(1, {
                    "data_indexed": True,
                    "title": embedding_search.title
                })
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, http_status=500)
        except EmbeddingRunningException:
            error_msg = "Page index is running now"

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

            return await utils.web.api_response(-4, error={
                "code": "page-index-running",
                "message": error_msg
            }, http_status=429)
        except Exception as e:
            error_msg = str(e)

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            if transaction_id:
                await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, http_status=500)

    @staticmethod
    @utils.web.token_auth
    async def search(request: web.Request):
        """Run an embedding search scoped to a page.

        Request parameters:
            title: Required. Title of the page to search.
            query: Required. Query text.
            limit: Optional int, default 5. Clamped to the range 1..10.
            incollection: Optional bool, default ``False``. Forwarded to the
                search service.
            distancelimit: Optional float, default ``None``. Forwarded to
                the search service.

        Returns an API response whose data holds ``results`` and
        ``token_usage``, or an error response (404 / 500).
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "query": {
                "required": True
            },
            "limit": {
                "required": False,
                "type": int,
                "default": 5
            },
            "incollection": {
                "required": False,
                "type": bool,
                "default": False
            },
            "distancelimit": {
                "required": False,
                "type": float,
                "default": None
            },
        })

        page_title = params.get('title')
        query = params.get('query')
        in_collection = params.get('incollection')
        distance_limit = params.get('distancelimit')

        # Cap the result count at 10, and never forward a zero or negative
        # limit to the search service.
        limit = max(1, min(params.get('limit'), 10))

        db = await DatabaseService.create(request.app)

        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                results, token_usage = await embedding_search.search(query, limit, in_collection, distance_limit)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title

            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, request=request, http_status=500)
        except Exception as e:
            error_msg = str(e)

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, request=request, http_status=500)

        return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def sys_update_title_info(request: web.Request):
        """Refresh the stored title index entry of a page (system callers only).

        Request parameters:
            title: Required. Title of the page to refresh.

        Requests whose ``caller`` is ``"user"`` are rejected with HTTP 403.
        Returns a bare success response, or an error response (404 / 500).
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
        })

        if request.get("caller") == "user":
            return await utils.web.api_response(-1, error={
                "code": "permission-denied",
                "message": "This api is only for system caller."
            }, request=request, http_status=403)

        page_title = params.get("title")

        db = await DatabaseService.create(request.app)

        async with EmbeddingSearchService(db, page_title) as embedding_search:
            try:
                await embedding_search.update_title_index(True)
            except MediaWikiPageNotFoundException:
                error_msg = "Page \"%s\" not found." % page_title

                return await utils.web.api_response(-2, error={
                    "code": "page-not-found",
                    "title": page_title,
                    "message": error_msg
                }, request=request, http_status=404)
            except Exception as err:
                err_msg = str(err)

                # Log unexpected failures like the other handlers do instead
                # of silently turning them into a 500.
                print(err_msg, file=sys.stderr)
                traceback.print_exc()

                return await utils.web.api_response(-1, error={
                    "code": "internal-server-error",
                    "message": err_msg
                }, request=request, http_status=500)

        return await utils.web.api_response(1, request=request)