You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

229 lines
8.8 KiB
Python

import sys
import traceback
from aiohttp import web
from service.database import DatabaseService
from service.embedding_search import EmbeddingSearchService
from service.mediawiki_api import MediaWikiApi, MediaWikiApiException, MediaWikiPageNotFoundException
import utils.web
class EmbeddingSearch:
    """aiohttp request handlers for building and querying a page's embedding index.

    The public handlers (``index_page``, ``search``) are static methods wrapped
    with ``utils.web.token_auth``. The leading-underscore helpers are only
    called from ``index_page`` after that wrapper has run.
    """

    @staticmethod
    @utils.web.token_auth
    async def index_page(request: web.Request):
        """(Re)build the embedding index for a MediaWiki page.

        Request params:
            title: title of the page to index (required).

        The request is served in one of two ways:

        - WebSocket: ``progress`` events are streamed while indexing, followed
          by a single ``done`` or ``error`` event.
        - Plain HTTP: an ``utils.web.api_response`` envelope is returned.

        When ``request["caller"] == "user"`` and the index needs updating, a
        ``chat_complete`` transaction is started for that user. It is ended
        with the returned token usage on success and cancelled on failure.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
        })
        page_title = params.get('title')

        mwapi = MediaWikiApi.create()
        db = await DatabaseService.create(request.app)

        if utils.web.is_websocket(request):
            return await EmbeddingSearch._index_page_websocket(request, mwapi, db, page_title)
        return await EmbeddingSearch._index_page_http(request, mwapi, db, page_title)

    @staticmethod
    async def _cancel_transaction(mwapi: MediaWikiApi, transaction_id, error_msg: str) -> None:
        """Cancel ``transaction_id`` with ``error_msg`` if a transaction was started.

        A failure to cancel is logged rather than raised. This keeps it from
        replacing the original error that the caller is about to report.
        """
        if not transaction_id:
            return
        try:
            await mwapi.chat_complete_cancel_transaction(transaction_id, error_msg)
        except Exception:
            print("Failed to cancel transaction %s" % transaction_id, file=sys.stderr)
            traceback.print_exc()

    @staticmethod
    async def _index_page_websocket(request: web.Request, mwapi: MediaWikiApi,
                                    db: DatabaseService, page_title: str) -> web.WebSocketResponse:
        """Run the index update over a WebSocket, streaming progress to the client.

        Always closes the socket, and returns it so that aiohttp receives the
        response object it requires from a handler.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        transaction_id = None
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                if await embedding_search.should_update_page_index():
                    if request.get("caller") == "user":
                        user_id = request.get("user")
                        transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")

                    await embedding_search.prepare_update_index()

                    async def on_progress(current, total):
                        await ws.send_json({
                            'event': 'progress',
                            'current': current,
                            'total': total
                        })

                    token_usage = await embedding_search.update_page_index(on_progress)

                    # Settle the transaction before talking to the client: if the
                    # client has gone away, the send below raises, and the
                    # completed work must not then be cancelled.
                    if transaction_id:
                        await mwapi.chat_complete_end_transaction(transaction_id, token_usage)
                        transaction_id = None

                    await ws.send_json({
                        'event': 'done',
                        'status': 1,
                        'index_updated': True
                    })
                else:
                    await ws.send_json({
                        'event': 'done',
                        'status': 1,
                        'index_updated': False
                    })
        # In every handler, cancel before sending: a disconnected client makes
        # send_json raise, which would otherwise skip the cancellation.
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            await ws.send_json({
                'event': 'error',
                'status': -2,
                'message': error_msg,
                'error': {
                    'code': 'page_not_found',
                    'title': page_title,
                },
            })
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            await ws.send_json({
                'event': 'error',
                'status': -3,
                'message': error_msg,
                'error': {
                    'code': e.code,
                    'info': e.info,
                },
            })
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            await ws.send_json({
                'event': 'error',
                'status': -1,
                'message': error_msg
            })
        finally:
            await ws.close()
        return ws

    @staticmethod
    async def _index_page_http(request: web.Request, mwapi: MediaWikiApi,
                               db: DatabaseService, page_title: str):
        """Run the index update and answer with a single ``api_response`` envelope.

        ``data_indexed`` in the response tells whether the index was rebuilt.
        """
        transaction_id = None
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                if not await embedding_search.should_update_page_index():
                    return await utils.web.api_response(1, {"data_indexed": False})

                if request.get("caller") == "user":
                    user_id = request.get("user")
                    transaction_id = await mwapi.chat_complete_start_transaction(user_id, "embeddingpage")

                await embedding_search.prepare_update_index()
                token_usage = await embedding_search.update_page_index()

                if transaction_id:
                    await mwapi.chat_complete_end_transaction(transaction_id, token_usage)
                    # Already settled; a later failure must not cancel it.
                    transaction_id = None

                return await utils.web.api_response(1, {"data_indexed": True})
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, http_status=500)
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            await EmbeddingSearch._cancel_transaction(mwapi, transaction_id, error_msg)
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, http_status=500)

    @staticmethod
    @utils.web.token_auth
    async def search(request: web.Request):
        """Query a page's embedding index with ``query``.

        Request params:
            title: page whose index is searched (required).
            query: search text (required).
            limit: maximum number of results, clamped to 1..10 (default 5).
            incollection: forwarded to the search as ``in_collection`` (default False).
            distancelimit: forwarded as ``distance_limit`` (default 0.6).

        Returns an ``api_response`` containing ``results`` and ``token_usage``
        on success. On failure it returns an error envelope: 404 if the page
        is missing, 500 otherwise.
        """
        params = await utils.web.get_param(request, {
            "title": {
                "required": True,
            },
            "query": {
                "required": True
            },
            "limit": {
                "required": False,
                "type": int,
                "default": 5
            },
            "incollection": {
                "required": False,
                "type": bool,
                "default": False
            },
            "distancelimit": {
                "required": False,
                "type": float,
                "default": 0.6
            },
        })
        page_title = params.get('title')
        query = params.get('query')
        in_collection = params.get('incollection')
        distance_limit = params.get('distancelimit')
        # Client-supplied: cap the result count and reject zero/negative values.
        limit = max(1, min(params.get('limit'), 10))

        db = await DatabaseService.create(request.app)
        try:
            async with EmbeddingSearchService(db, page_title) as embedding_search:
                results, token_usage = await embedding_search.search(query, limit, in_collection, distance_limit)
        except MediaWikiPageNotFoundException:
            error_msg = "Page \"%s\" not found." % page_title
            return await utils.web.api_response(-2, error={
                "code": "page-not-found",
                "title": page_title,
                "message": error_msg
            }, request=request, http_status=404)
        except MediaWikiApiException as e:
            error_msg = "MediaWiki API error: %s" % e.info
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            return await utils.web.api_response(-3, error={
                "code": "mediawiki-api-error",
                "info": e.info,
                "message": error_msg
            }, request=request, http_status=500)
        except Exception as e:
            error_msg = str(e)
            print(error_msg, file=sys.stderr)
            traceback.print_exc()
            return await utils.web.api_response(-1, error={
                "code": "internal-server-error",
                "message": error_msg
            }, request=request, http_status=500)
        return await utils.web.api_response(1, data={"results": results, "token_usage": token_usage}, request=request)