You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

325 lines
13 KiB
Python

import sys
import traceback
from aiohttp import web
from libs.config import Config
from server.model.embedding_search.title_collection import TitleCollectionHelper
from server.model.embedding_search.title_index import TitleIndexHelper
from service.database import DatabaseService
from service.embedding_search import EmbeddingRunningException, EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
import utils.web
class EmbeddingSearch:
    """aiohttp request handlers for the embedding-search API.

    Each public handler is a ``@staticmethod`` wrapped by
    ``utils.web.token_auth``; route registration happens outside this class.
    """

    @staticmethod
    @utils.web.token_auth
    async def index_page(request: web.Request):
        """Build or refresh the embedding index of a page (or a collection).

        Request parameters:
            title: page title to index (required).
            collection: bool, default False. Only honoured on WebSocket
                requests: when true and a title collection named ``title``
                exists, every page that
                ``TitleIndexHelper.get_need_update_index_list`` returns for
                it is indexed. The plain-HTTP path always indexes ``title``
                alone.

        WebSocket requests receive ``progress`` events while indexing and then
        exactly one ``done`` or ``error`` event. Plain HTTP requests receive a
        single ``utils.web.api_response`` reply.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "collection": {
                "required": False,
                "type": bool,
                "default": False
            }
        })
        page_title = params.get("title")
        is_collection = params.get("collection")
        mwapi = MediaWikiApi.create()
        db = await DatabaseService.create(request.app)

        # WebSocket clients get streamed progress; plain HTTP gets one JSON reply.
        if utils.web.is_websocket(request):
            return await EmbeddingSearch._index_page_websocket(
                request, db, mwapi, page_title, is_collection
            )
        return await EmbeddingSearch._index_page_http(request, db, mwapi, page_title)

    @staticmethod
    async def _cancel_transaction(mwapi, transaction_id, error_msg):
        """Cancel the usage transaction ``transaction_id`` if one was started."""
        if transaction_id:
            await mwapi.ai_toolbox_cancel_transaction(transaction_id, error_msg)

    @staticmethod
    async def _collect_titles_to_index(db, page_title, is_collection):
        """Return the page titles the WebSocket indexer should process.

        Without ``is_collection``, or when no title collection named
        ``page_title`` exists, this is ``[page_title]``. Otherwise it is the
        titles returned by ``get_need_update_index_list`` for that collection
        (possibly an empty list).
        """
        if not is_collection:
            return [page_title]
        async with TitleCollectionHelper(db) as title_collection_helper, \
                TitleIndexHelper(db) as title_index_helper:
            collection = await title_collection_helper.find_by_title(page_title)
            if collection is None:
                return [page_title]
            need_update_pages = await title_index_helper.get_need_update_index_list(collection.id)
            return [page_info.title for page_info in need_update_pages]

    @staticmethod
    async def _index_page_websocket(request, db, mwapi, page_title, is_collection):
        """Index one page or a collection over a WebSocket, streaming progress.

        The socket is always closed when indexing finishes or fails, and the
        ``WebSocketResponse`` is returned because aiohttp expects every
        handler to return its response object.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        # Usage billing is disabled (see the commented-out code below), so this
        # currently stays None; the cancel/end calls are kept for when it returns.
        transaction_id = None
        try:
            title_list = await EmbeddingSearch._collect_titles_to_index(
                db, page_title, is_collection
            )
            page_count = len(title_list)
            page_current = 0
            index_updated = False

            async def on_progress(current, total):
                # Reads page_current at call time, so it reports the page
                # that is being indexed when the callback fires.
                await ws.send_json({
                    "event": "progress",
                    "current": current,
                    "total": total,
                    "current_page": page_current,
                    "total_page": page_count,
                })

            for one_title in title_list:
                page_current += 1
                async with EmbeddingSearchService(db, one_title) as embedding_search:
                    if await embedding_search.should_update_page_index():
                        # if request.get("caller") == "user":
                        #     user_id = request.get("user")
                        #     usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                        #     transaction_id = usage_res.get("transaction_id")
                        await embedding_search.prepare_update_index()
                        token_usage = await embedding_search.update_page_index(on_progress)
                        index_updated = True
                        if transaction_id:
                            await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)
                            # Forget the finished transaction so a failure on a
                            # later page does not try to cancel it.
                            transaction_id = None

            await ws.send_json({
                "event": "done",
                "status": 1,
                "index_updated": index_updated,
            })
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            await ws.send_json({
                "event": "error",
                "status": -2,
                "message": error_msg,
                "error": {
                    "code": "page-not-found",
                    "title": page_title,
                },
            })
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await ws.send_json({
                "event": "error",
                "status": -3,
                "message": error_msg,
                "error": {
                    "code": e.code,
                    "info": e.info,
                },
            })
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
        except EmbeddingRunningException:
            error_msg = "Page index is running now"
            await ws.send_json({
                "event": "error",
                "status": -4,
                "message": error_msg,
                "error": {
                    "code": "page-index-running",
                },
            })
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
        except ConnectionResetError:
            pass  # Ignore websocket close error
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await ws.send_json({
                "event": "error",
                "status": -1,
                "message": error_msg,
                "error": {
                    "code": "internal-server-error",
                }
            })
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
        finally:
            await ws.close()
        return ws

    @staticmethod
    async def _index_page_http(request, db, mwapi, page_title):
        """Index a single page and answer with one ``api_response``.

        On success ``data_indexed`` says whether an index update actually ran.
        Error statuses: -2 page not found (404), -3 MediaWiki API error (500),
        -4 indexing already running (429), -1 anything else (500).
        """
        # Usage billing is disabled (see the commented-out code below), so this
        # currently stays None.
        transaction_id = None
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                if not await embedding_search.should_update_page_index():
                    return await utils.web.api_response(1, {
                        "data_indexed": False,
                        "title": embedding_search.title
                    }, request=request)

                # if request.get("caller") == "user":
                #     user_id = request.get("user")
                #     usage_res = await mwapi.ai_toolbox_start_transaction(user_id, "embeddingpage")
                #     transaction_id = usage_res.get("transaction_id")
                await embedding_search.prepare_update_index()
                token_usage = await embedding_search.update_page_index()
                if transaction_id:
                    await mwapi.ai_toolbox_end_transaction(transaction_id, token_usage)
                return await utils.web.api_response(1, {
                    "data_indexed": True,
                    "title": embedding_search.title
                }, request=request)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, request=request, http_status=500)
        except EmbeddingRunningException:
            error_msg = "Page index is running now"
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-4, error={
                "code": "page-index-running",
                "message": error_msg
            }, request=request, http_status=429)
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, request=request, http_status=500)

    @staticmethod
    @utils.web.token_auth
    async def search(request: web.Request):
        """Run an embedding search scoped to a page.

        Request parameters:
            title, query: required.
            limit: int, default 5, clamped to the range 1..10.
            incollection: bool, default False.
            distancelimit: float, default None.

        Replies with ``{"results": ..., "token_usage": ...}`` on success, or
        an error status -2 (page not found, 404), -3 (MediaWiki API error,
        500) or -1 (anything else, 500).
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "query": {
                "required": True
            },
            "limit": {
                "required": False,
                "type": int,
                "default": 5
            },
            "incollection": {
                "required": False,
                "type": bool,
                "default": False
            },
            "distancelimit": {
                "required": False,
                "type": float,
                "default": None
            },
        })
        page_title = params.get('title')
        query = params.get('query')
        in_collection = params.get('incollection')
        distance_limit = params.get('distancelimit')
        # Cap at 10 results; also reject zero/negative values, which are
        # presumably meaningless for a top-k search downstream.
        limit = max(1, min(params.get('limit'), 10))
        db = await DatabaseService.create(request.app)
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                results, token_usage = await embedding_search.search(
                    query, limit, in_collection, distance_limit
                )
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, request=request, http_status=500)
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, request=request, http_status=500)
        return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)

    @staticmethod
    @utils.web.token_auth
    async def sys_update_title_info(request: web.Request):
        """Refresh a page's title index; rejected (403) for ``user`` callers.

        Calls ``EmbeddingSearchService.update_title_index(True)`` for the page
        ``title`` and replies with status 1, -2 (page not found, 404) or -1
        (any other error, 500).
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
        })
        if request.get("caller") == "user":
            return await utils.web.api_response(-1, error={
                "code": "permission-denied",
                "message": "This api is only for system caller."
            }, request=request, http_status=403)
        page_title = params.get("title")
        db = await DatabaseService.create(request.app)
        # The try wraps the whole ``async with`` (as in the other handlers) so
        # failures while entering/leaving the service are reported too.
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                await embedding_search.update_title_index(True)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except Exception as err:
            err_msg = str(err)
            print(err_msg, file=sys.stderr)
            traceback.print_exc()
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": err_msg
            }, request=request, http_status=500)
        return await utils.web.api_response(1, request=request)