You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
7.1 KiB
Python
177 lines
7.1 KiB
Python
2 years ago
|
from __future__ import annotations
|
||
|
from functools import wraps
|
||
|
from typing import Any, Optional, Dict
|
||
|
from aiohttp import web
|
||
|
import jwt
|
||
|
import config
|
||
|
|
||
|
# Validation rule for a single request parameter, keyed by option name
# (e.g. "required", "type", "default").
ParamRule = Dict[str, Any]


class ParamInvalidException(Exception):
    """Signal that one or more request parameters failed validation.

    Attributes:
        code: machine-readable error code, always ``"param_invalid"``.
        param_list: names of the offending parameters, in the order found.
        rules: the complete rule mapping the parameters were checked against.
    """

    def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
        self.code = "param_invalid"
        self.param_list = param_list
        self.rules = rules
        # Every name is single-quoted, e.g. "'a', 'b'"; an empty list renders as "''".
        quoted_names = "', '".join(param_list)
        super().__init__(f"Param invalid: '{quoted_names}'")
|
||
|
|
||
|
async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
    """Collect request parameters and optionally validate/coerce them.

    Parameters come from the query string and, for POST requests, from the
    body. The body is read as a JSON object when the content type is
    ``application/json``; otherwise the form data from ``request.post()`` is
    used. Body values override query values with the same key.

    Each rule in *rules* may contain:
      * ``"required"``: truthy if the parameter must be present and not None.
      * ``"type"``: ``int``/``float``/``bool`` to coerce the value, or a
        collection (dict/list/tuple/set/frozenset) of allowed values.
      * ``"default"``: value used when the parameter is absent or None.

    Returns:
        dict of parameter name -> value. When *rules* is given, every rule key
        is present in the result (None if absent and no default).

    Raises:
        ParamInvalidException: if any parameter is missing or invalid.
    """
    params = await _read_request_params(request)
    if rules is not None:
        _apply_param_rules(params, rules)
    return params


async def _read_request_params(request: web.Request) -> dict[str, Any]:
    """Merge query-string parameters with POST body parameters (body wins)."""
    params: dict[str, Any] = {}
    # A MultiDict may repeat a key; feeding its items in order keeps the last value.
    params.update(request.query.items())

    if request.method == 'POST':
        # ``content_type`` is the bare mimetype, so "application/json; charset=utf-8" matches too.
        if request.content_type == 'application/json':
            body = await request.json()
            # Only a JSON object can be merged; arrays, scalars and null are ignored.
            if isinstance(body, dict):
                params.update(body)
        else:
            form = await request.post()
            params.update(form.items())
    return params


def _apply_param_rules(params: dict[str, Any], rules: dict[str, ParamRule]) -> None:
    """Validate and coerce *params* in place; raise ParamInvalidException on failure."""
    # A rule "type" that is one of these collections lists the allowed values.
    choice_types = (dict, list, tuple, set, frozenset)
    invalid_params: list[str] = []

    for key, rule in rules.items():
        value = params.get(key)

        # Absent or explicit null: required -> invalid, otherwise apply the default.
        if value is None:
            if rule.get("required"):
                invalid_params.append(key)
            else:
                params[key] = rule.get("default")
            continue

        if "type" not in rule:
            continue
        expected = rule["type"]

        if isinstance(expected, choice_types):
            try:
                allowed = value in expected
            except TypeError:  # unhashable value (e.g. a JSON list) tested against a dict/set
                allowed = False
            if not allowed:
                invalid_params.append(key)
            continue

        try:
            if expected is int:
                params[key] = int(value)
            elif expected is float:
                params[key] = float(value)
            elif expected is bool:
                # JSON may already supply a real bool; strings "false"/"0" (any case) mean False.
                if not isinstance(value, bool):
                    value = str(value).lower() not in ("false", "0")
                params[key] = value
        except (TypeError, ValueError):
            invalid_params.append(key)

    if invalid_params:
        raise ParamInvalidException(invalid_params, rules)
|
||
|
|
||
|
async def api_response(status, data=None, error=None, warning=None, http_status=200, request: Optional[web.Request] = None):
    """Build the standard API envelope and deliver it to the client.

    The payload is ``{"status": status}`` plus ``data``, ``error`` and
    ``warning`` keys for whichever of those arguments are not None.

    If *request* is a WebSocket upgrade (per ``is_websocket``), ``"event":
    "response"`` is added. The payload is then sent as one JSON message on a
    freshly prepared ``web.WebSocketResponse``, the socket is closed, and that
    response is returned so the aiohttp handler can return it. Otherwise a JSON
    HTTP response with status code *http_status* is returned.
    """
    ret = {"status": status}
    # ``is not None`` so legitimately falsy payloads (0, "", [], {}) are still sent.
    if data is not None:
        ret["data"] = data
    if error is not None:
        ret["error"] = error
    if warning is not None:
        ret["warning"] = warning

    # aiohttp requests are MutableMappings: ``bool(request)`` is False when no
    # values are stored on them, so test identity rather than truthiness.
    if request is not None and is_websocket(request):
        ret["event"] = "response"
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await ws.send_json(ret)
        await ws.close()
        # aiohttp handlers must return the response object they prepared.
        return ws

    return web.json_response(ret, status=http_status)
|
||
|
|
||
|
def is_websocket(request: "web.Request"):
    """Tell whether *request* asks for a WebSocket upgrade.

    The ``Upgrade`` header is compared case-insensitively with
    ``"websocket"``; a missing header counts as a plain HTTP request.
    """
    upgrade_header = request.headers.get('Upgrade', '')
    return upgrade_header.lower() == 'websocket'
|
||
|
|
||
|
async def _auth_error(request: web.Request, code: str, message: str, target: Optional[str] = None):
    """Return a 401 ``api_response`` whose error body is ``{code, [target], message}``."""
    error = {"code": code}
    if target is not None:
        error["target"] = target
    error["message"] = message
    return await api_response(status=-1, error=error, http_status=401, request=request)


# Auth decorator
def token_auth(f):
    """Decorate an async request handler so it runs only for authenticated callers.

    Where the token comes from:
      * If the ``token`` request parameter is non-empty, it is always treated
        as a JWT.
      * Otherwise the ``Authorization`` header is used, with an optional
        ``Bearer `` prefix stripped.

    How the token is checked:
      * Header tokens starting with ``sk_`` are server keys. They must be
        present in ``config.AUTH_TOKENS``. On success ``request["user"]`` is
        set from the ``user_id`` parameter (default 0) and
        ``request["caller"]`` is set to ``"server"``.
      * Any other token is verified as an HS256/384/512 JWT. The secret is
        ``config.AUTH_TOKENS[kid]``, with ``kid`` read from the token header.
        On success ``request["user"]`` is set to the ``sub`` claim and
        ``request["caller"]`` is set to ``"user"``.

    On failure the wrapper returns an error ``api_response`` without calling
    *f*: a 401 for missing or invalid tokens, or a 500 for unexpected decoding
    errors.
    """
    @wraps(f)
    async def decorated_function(*args, **kwargs):
        # The request is expected as the first positional argument.
        request: web.Request = args[0]
        params = await get_param(request)

        jwt_token = None
        sk_token = None
        token = params.get("token")
        if token:
            jwt_token = token
        else:
            header = request.headers.get('Authorization')
            if header is None:
                return await _auth_error(request, "missing-token", "Missing token.")
            # Strip only a leading scheme, not occurrences elsewhere in the value.
            if header.startswith("Bearer "):
                header = header[len("Bearer "):]
            if header.startswith("sk_"):
                sk_token = header
            else:
                jwt_token = header

        if sk_token is not None:
            if sk_token not in config.AUTH_TOKENS:
                return await _auth_error(request, "token-invalid", "Token invalid.", target="token_id")
            request["user"] = params.get("user_id", 0)
            request["caller"] = "server"
        elif jwt_token is not None:
            # Resolve the signing secret from the (unverified) ``kid`` header. An
            # unknown kid is a client error (401), not a server error.
            try:
                key_id = jwt.get_unverified_header(jwt_token)["kid"]
                secret = config.AUTH_TOKENS[key_id]
            except (KeyError, TypeError, jwt.exceptions.InvalidTokenError):
                return await _auth_error(request, "token-invalid", "Token issuer not exists.", target="token_id")

            # Verify signature and standard claims.
            try:
                data = jwt.decode(jwt_token, secret, algorithms=['HS256', 'HS384', 'HS512'])
            except jwt.exceptions.ExpiredSignatureError:
                return await _auth_error(request, "token-invalid", "Token expired.", target="expire")
            except (jwt.exceptions.DecodeError, jwt.exceptions.InvalidSignatureError, jwt.exceptions.InvalidAlgorithmError):
                return await _auth_error(request, "token-invalid", "Invalid signature.", target="signature")
            except jwt.exceptions.InvalidTokenError:
                # Remaining claim failures (e.g. not-yet-valid, audience/issuer mismatch).
                return await _auth_error(request, "token-invalid", "Token invalid.", target="token")
            except Exception as e:
                return await api_response(status=-1, error=str(e), http_status=500, request=request)

            if "sub" not in data:
                return await _auth_error(request, "token-invalid", "Token subject invalid.", target="subject")
            request["user"] = data["sub"]
            request["caller"] = "user"
        else:
            return await _auth_error(request, "missing-token", "Missing token.")

        return await f(*args, **kwargs)

    return decorated_function
|