You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
196 lines
7.9 KiB
Python
196 lines
7.9 KiB
Python
from __future__ import annotations
|
|
from functools import wraps
|
|
import json
|
|
from typing import Any, Optional, Dict
|
|
from aiohttp import web
|
|
import jwt
|
|
import uuid
|
|
from config import Config
|
|
|
|
# Validation rule for a single request parameter, as consumed by get_param();
# the keys it reads are "required", "type" and "default".
ParamRule = Dict[str, Any]
|
|
|
|
class ParamInvalidException(web.HTTPBadRequest):
    """HTTP 400 raised by get_param() when parameters are missing or invalid.

    Attributes:
        code: machine-readable error code, always ``"param_invalid"``.
        param_list: names of the offending parameters.
        rules: the validation rules that were being applied.
    """

    def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
        param_list_str = ", ".join(f"'{name}'" for name in param_list)
        # aiohttp's HTTPException.__init__ accepts keyword arguments only; a
        # positional message raises TypeError, so pass it as the body text.
        super().__init__(text=f"Param invalid: {param_list_str}")
        self.code = "param_invalid"
        self.param_list = param_list
        self.rules = rules
|
|
|
|
async def _collect_request_params(request: web.Request) -> dict[str, Any]:
    """Merge route, query-string and (for POST) body parameters into one dict.

    Sources are applied in that order, so a later source overwrites an earlier
    one for the same key; for repeated query/form keys the last one wins.

    Raises:
        web.HTTPBadRequest: if a JSON body cannot be decoded.
    """
    params: dict[str, Any] = dict(request.match_info)
    params.update(request.query.items())
    if request.method != "POST":
        return params

    # ``content_type`` is the bare MIME type, so a header such as
    # "application/json; charset=utf-8" is recognised as JSON too.
    if request.content_type == "application/json":
        try:
            body = await request.json()
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            raise web.HTTPBadRequest(text="Invalid JSON body.") from exc
        # Only a top-level JSON object contributes parameters.
        if isinstance(body, dict):
            params.update(body)
    else:
        form = await request.post()
        params.update(form.items())
    return params


def _to_bool(value: Any) -> bool:
    """Interpret a request value as a boolean.

    Strings "false"/"0" (any case) and numeric zero are False; bools are kept;
    every other value is True.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.lower() not in ("false", "0")
    if isinstance(value, (int, float)):
        return value != 0
    return True


def _coerce_param(value: Any, rule_type: Any) -> Any:
    """Validate/convert *value* according to a rule's ``"type"`` entry.

    * a dict instance: *value* must be one of its keys (returned unchanged);
    * ``dict``: *value* must be a JSON object (returned unchanged);
    * ``int`` / ``float`` / ``bool``: *value* is converted;
    * anything else: *value* is returned unchanged.

    Raises:
        ValueError, TypeError, OverflowError: if *value* does not conform.
    """
    if isinstance(rule_type, dict):
        try:
            allowed = value in rule_type
        except TypeError:  # unhashable value, e.g. a JSON list
            allowed = False
        if not allowed:
            raise ValueError(f"{value!r} is not an allowed value")
        return value
    if rule_type is dict:
        if not isinstance(value, dict):
            raise ValueError("expected an object")
        return value
    if rule_type is int:
        return int(value)
    if rule_type is float:
        return float(value)
    if rule_type is bool:
        return _to_bool(value)
    return value


async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
    """Collect request parameters and optionally validate/coerce them.

    Parameters come from the URL route, the query string and, for POST
    requests, the body (JSON object or form data); later sources override
    earlier ones.

    Each entry of *rules* maps a parameter name to a dict that may contain:

    * ``"required"``: truthy if the parameter must be present;
    * ``"type"``: ``int``/``float``/``bool`` to convert the value, ``dict`` to
      require a JSON object, or a dict instance whose keys are the allowed
      values;
    * ``"default"``: value stored when the parameter is absent (``None`` if
      not given).

    Args:
        request: the incoming aiohttp request.
        rules: optional validation rules keyed by parameter name.

    Returns:
        Dict of parameter name to (possibly converted) value.

    Raises:
        ParamInvalidException: naming every missing or invalid parameter.
        web.HTTPBadRequest: if a JSON body cannot be decoded.
    """
    params = await _collect_request_params(request)
    if rules is None:
        return params

    invalid_params: list[str] = []
    for key, rule in rules.items():
        if key not in params:
            if rule.get("required"):
                invalid_params.append(key)
            else:
                params[key] = rule.get("default")
            continue
        if "type" not in rule:
            continue
        try:
            params[key] = _coerce_param(params[key], rule["type"])
        except (TypeError, ValueError, OverflowError):
            # e.g. int("abc"), int(None) from JSON null, int(float("inf")).
            invalid_params.append(key)

    if invalid_params:
        raise ParamInvalidException(invalid_params, rules)
    return params
|
|
|
|
async def api_response(status, data=None, error=None, warning=None, http_status=200, request: Optional[web.Request] = None):
    """Send the standard JSON envelope over HTTP or, if requested, a WebSocket.

    Args:
        status: value of the ``"status"`` field.
        data, error, warning: optional fields; each is included only when it
            is not None (so falsy payloads such as ``0`` or ``[]`` are kept).
        http_status: HTTP status of the plain JSON response; not used for
            WebSocket replies.
        request: the incoming request. When it asks for a WebSocket upgrade,
            the envelope plus ``"event": "response"`` is sent as one JSON
            message and the socket is closed.

    Returns:
        The prepared ``web.WebSocketResponse`` in the WebSocket case, otherwise
        a ``web.json_response``.
    """
    ret: dict[str, Any] = {"status": status}
    if data is not None:
        ret["data"] = data
    if error is not None:
        ret["error"] = error
    if warning is not None:
        ret["warning"] = warning

    if request is not None and is_websocket(request):
        ret["event"] = "response"
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await ws.send_json(ret)
        await ws.close()
        # aiohttp requires the handler to return the response it prepared;
        # returning None fails with "Missing return statement".
        return ws

    return web.json_response(ret, status=http_status)
|
|
|
|
def is_websocket(request: web.Request) -> bool:
    """Tell whether *request* asks to be upgraded to a WebSocket.

    The ``Upgrade`` header (treated as empty when absent) is compared with
    ``"websocket"`` case-insensitively.
    """
    upgrade_header = request.headers.get('Upgrade', '')
    return 'websocket' == upgrade_header.lower()
|
|
|
|
def generate_uuid() -> str:
    """Return a fresh random (version 4) UUID as a hyphenated string."""
    fresh = uuid.uuid4()
    return f"{fresh}"
|
|
|
|
class _AuthFailure(Exception):
    """Internal signal raised by the token_auth helpers when authentication fails.

    Carries the ``error`` payload and HTTP status that token_auth passes to
    api_response.
    """

    def __init__(self, error: Any, http_status: int = 401):
        super().__init__(error)
        self.error = error
        self.http_status = http_status


def _auth_error(code: str, message: str, target: Optional[str] = None) -> _AuthFailure:
    """Build a 401 _AuthFailure with payload ``{"code", ["target",] "message"}``."""
    error: dict[str, str] = {"code": code}
    if target is not None:
        error["target"] = target
    error["message"] = message
    return _AuthFailure(error)


def _verify_secret_key(request: web.Request, params: dict[str, Any], auth_tokens: dict, sk_token: str) -> None:
    """Authenticate a server ``sk-`` key against the configured secrets.

    On success ``request["user"]`` is taken from the ``user_id`` parameter
    (default ``0``) and ``request["caller"]`` is set to ``"server"``.

    Raises:
        _AuthFailure: if *sk_token* equals none of the string values of
            *auth_tokens*.
    """
    import hmac

    candidate = sk_token.encode("utf-8", "surrogatepass")
    matched = False
    # Constant-time comparison against every secret (no early exit), so the
    # response time does not reveal how close a guessed key was.
    for secret in auth_tokens.values():
        if isinstance(secret, str) and hmac.compare_digest(secret.encode("utf-8", "surrogatepass"), candidate):
            matched = True
    if not matched:
        raise _auth_error("token-invalid", "Token invalid.", target="token_id")

    request["user"] = params.get("user_id", 0)
    request["caller"] = "server"


def _verify_jwt(request: web.Request, auth_tokens: dict, jwt_token: Any) -> None:
    """Authenticate an HMAC-signed JWT whose ``kid`` header names its secret.

    On success ``request["user"]`` is set to the ``sub`` claim and
    ``request["caller"]`` to ``"user"``.

    Raises:
        _AuthFailure: HTTP 401 for an unknown/malformed issuer, bad signature,
            expired or otherwise invalid token, or a missing ``sub`` claim;
            HTTP 500 for any other error raised while decoding.
    """
    # The signing secret is looked up by the (not yet verified) "kid" header.
    try:
        key_id = jwt.get_unverified_header(jwt_token)["kid"]
        secret = auth_tokens[key_id]
    except (KeyError, TypeError, jwt.exceptions.InvalidTokenError):
        raise _auth_error("token-invalid", "Token issuer not exists.", target="token_id") from None

    try:
        claims = jwt.decode(jwt_token, secret, algorithms=['HS256', 'HS384', 'HS512'])
    except (jwt.exceptions.DecodeError, jwt.exceptions.InvalidSignatureError, jwt.exceptions.InvalidAlgorithmError):
        raise _auth_error("token-invalid", "Invalid signature.", target="signature") from None
    except jwt.exceptions.ExpiredSignatureError:
        raise _auth_error("token-invalid", "Token expired.", target="expire") from None
    except jwt.exceptions.InvalidTokenError:
        # Remaining claim-validation failures (e.g. a token that is not yet
        # valid) are client errors, not server faults.
        raise _auth_error("token-invalid", "Token invalid.", target="token_id") from None
    except Exception as e:
        raise _AuthFailure(str(e), http_status=500) from e

    if "sub" not in claims:
        raise _auth_error("token-invalid", "Token subject invalid.", target="subject")

    request["user"] = claims["sub"]
    request["caller"] = "user"


# Auth decorator
def token_auth(f):
    """Wrap an aiohttp handler so that it only runs for authenticated callers.

    The token is read from the ``token`` request parameter, which is always
    treated as a JWT. Failing that, it is read from the ``Authorization``
    header, with a leading ``"Bearer "`` stripped. A header token starting
    with ``sk-`` must equal one of the values of the ``authorization`` config
    mapping (server caller). Any other token is a JWT whose ``kid`` header
    selects its HMAC secret from that same mapping (user caller).

    On success ``request["user"]`` and ``request["caller"]`` are set and the
    handler is awaited. Otherwise an error response from api_response is
    returned. The handler's first positional argument must be the request.
    """

    @wraps(f)
    async def decorated_function(*args, **kwargs):
        request: web.Request = args[0]
        auth_tokens: dict = Config.get("authorization") or {}
        params = await get_param(request)

        try:
            param_token = params.get("token")
            if param_token:
                _verify_jwt(request, auth_tokens, param_token)
            else:
                # aiohttp headers are case-insensitive, one lookup suffices.
                header = request.headers.get('Authorization')
                if not header:
                    raise _auth_error("missing-token", "Missing token.")
                token = header[len("Bearer "):] if header.startswith("Bearer ") else header
                if token.startswith("sk-"):
                    _verify_secret_key(request, params, auth_tokens, token)
                else:
                    _verify_jwt(request, auth_tokens, token)
        except _AuthFailure as failure:
            return await api_response(status=-1, error=failure.error, http_status=failure.http_status, request=request)

        return await f(*args, **kwargs)

    return decorated_function