You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
229 lines
8.8 KiB
Python
229 lines
8.8 KiB
Python
import sys
|
|
import traceback
|
|
from aiohttp import web
|
|
from service.database import DatabaseService
|
|
from service.embedding_search import EmbeddingSearchService
|
|
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
|
|
import utils.web
|
|
|
|
class EmbeddingSearch:
    """aiohttp request handlers for indexing and searching page embeddings.

    Both handlers are static methods wrapped by ``utils.web.token_auth`` and
    delegate the actual work to ``EmbeddingSearchService``.
    """

    @staticmethod
    @utils.web.token_auth
    async def index_page(request: web.Request):
        """Update the embedding index of the page given by the ``title`` param.

        The index is only rebuilt when
        ``EmbeddingSearchService.should_update_page_index()`` says so. When the
        request carries ``caller == "user"``, the update is bracketed by a
        MediaWiki "embeddingpage" transaction: started before indexing, ended
        with the reported token usage on success, cancelled with the error
        message on failure.

        WebSocket requests receive JSON frames:
            * ``{"event": "progress", "current", "total"}`` while indexing,
            * ``{"event": "done", "status": 1, "index_updated": bool}`` on
              success,
            * ``{"event": "error", "status": -1 | -2 | -3, "message", ...}``
              on failure,
          after which the socket is closed and the ``WebSocketResponse`` is
          returned.

        Plain HTTP requests receive the result of ``utils.web.api_response``:
        status ``1`` with ``{"data_indexed": bool}`` on success, ``-2`` (404)
        when the page does not exist, ``-3`` (500) on a MediaWiki API error and
        ``-1`` (500) on any other exception.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
        })

        page_title = params.get('title')

        mwapi = MediaWikiApi.create()
        db = await DatabaseService.create(request.app)

        # WebSocket clients get streamed progress events; every other client
        # gets a single JSON response.
        if utils.web.is_websocket(request):
            ws = web.WebSocketResponse()
            await ws.prepare(request)

            # Set only while a transaction is open and still needs to be
            # ended or cancelled.
            transaction_id = None
            try:
                async with EmbeddingSearchService(db, page_title) as embedding_search:
                    if await embedding_search.should_update_page_index():
                        if request.get("caller") == "user":
                            user_id = request.get("user")
                            transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")

                        await embedding_search.prepare_update_index()

                        async def on_progress(current, total):
                            await ws.send_json({
                                'event': 'progress',
                                'current': current,
                                'total': total
                            })

                        token_usage = await embedding_search.update_page_index(on_progress)

                        # End the transaction before talking to the client:
                        # the work is done, so a dropped socket must not
                        # prevent the usage from being recorded.
                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(transaction_id, token_usage)
                            # Ended transactions must not be cancelled by the
                            # error handlers below.
                            transaction_id = None

                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            'index_updated': True
                        })
                    else:
                        await ws.send_json({
                            'event': 'done',
                            'status': 1,
                            'index_updated': False
                        })
            except MediaWikiPageNotFoundException:
                error_msg = "Page \"%s\" not found." % page_title

                # Cancel first: sending may raise if the client is gone, and
                # the open transaction must be released regardless.
                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                await ws.send_json({
                    'event': 'error',
                    'status': -2,
                    'message': error_msg,
                    'error': {
                        'code': 'page_not_found',
                        'title': page_title,
                    },
                })
            except MediaWikiApiException as e:
                error_msg = "MediaWiki API error: %s" % str(e)

                print(error_msg, file=sys.stderr)
                traceback.print_exc()

                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                await ws.send_json({
                    'event': 'error',
                    'status': -3,
                    'message': error_msg,
                    'error': {
                        'code': e.code,
                        'info': e.info,
                    },
                })
            except Exception as e:
                error_msg = str(e)

                print(error_msg, file=sys.stderr)
                traceback.print_exc()

                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                await ws.send_json({
                    'event': 'error',
                    'status': -1,
                    'message': error_msg
                })
            finally:
                await ws.close()

            # aiohttp handlers must return the response object they prepared.
            return ws
        else:
            transaction_id = None
            try:
                async with EmbeddingSearchService(db, page_title) as embedding_search:
                    if await embedding_search.should_update_page_index():
                        if request.get("caller") == "user":
                            user_id = request.get("user")
                            transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")

                        await embedding_search.prepare_update_index()

                        token_usage = await embedding_search.update_page_index()

                        if transaction_id:
                            await mwapi.chat_complete_end_transaction(transaction_id, token_usage)

                        return await utils.web.api_response(1, {"data_indexed": True})
                    else:
                        return await utils.web.api_response(1, {"data_indexed": False})
            except MediaWikiPageNotFoundException:
                error_msg = "Page \"%s\" not found." % page_title

                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                return await utils.web.api_response(-2, error={
                    "code": "page-not-found",
                    "title": page_title,
                    "message": error_msg
                }, http_status=404)
            except MediaWikiApiException as e:
                error_msg = "MediaWiki API error: %s" % e.info

                print(error_msg, file=sys.stderr)
                traceback.print_exc()

                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                return await utils.web.api_response(-3, error={
                    "code": "mediawiki-api-error",
                    "info": e.info,
                    "message": error_msg
                }, http_status=500)
            except Exception as e:
                error_msg = str(e)

                print(error_msg, file=sys.stderr)
                traceback.print_exc()

                if transaction_id:
                    await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)

                return await utils.web.api_response(-1, error={
                    "code": "internal-server-error",
                    "message": error_msg
                }, http_status=500)

    @staticmethod
    @utils.web.token_auth
    async def search(request: web.Request):
        """Run an embedding search for ``query`` against the page ``title``.

        Request parameters:
            title: required page title.
            query: required search text.
            limit: optional int, default 5; values above 10 are capped to 10.
            incollection: optional bool, default False; forwarded to
                ``EmbeddingSearchService.search``.
            distancelimit: optional float, default 0.6; forwarded to
                ``EmbeddingSearchService.search``.

        Returns the result of ``utils.web.api_response``: status ``1`` with
        ``{"results", "token_usage"}`` on success, ``-2`` (404) when the page
        does not exist, ``-3`` (500) on a MediaWiki API error and ``-1`` (500)
        on any other exception.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "query": {
                "required": True
            },
            "limit": {
                "required": False,
                "type": int,
                "default": 5
            },
            "incollection": {
                "required": False,
                "type": bool,
                "default": False
            },
            "distancelimit": {
                "required": False,
                "type": float,
                "default": 0.6
            },
        })

        page_title = params.get('title')
        query = params.get('query')
        limit = params.get('limit')
        in_collection = params.get('incollection')
        distance_limit = params.get('distancelimit')

        # Upper bound on how many results a single request may ask for.
        limit = min(limit, 10)

        db = await DatabaseService.create(request.app)

        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                results, token_usage = await embedding_search.search(query, limit, in_collection, distance_limit)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title

            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, request=request, http_status=500)
        except Exception as e:
            error_msg = str(e)

            print(error_msg, file=sys.stderr)
            traceback.print_exc()

            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, request=request, http_status=500)

        return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)