You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

204 lines
8.3 KiB
Python

from __future__ import annotations

import hmac
import json
import uuid
from functools import wraps
from typing import Any, Optional, Dict

import jwt
from aiohttp import web

from config import Config

# One validation rule for a single request parameter, as passed to
# ``get_param``; the keys it reads are "required", "default" and "type".
ParamRule = Dict[str, Any]


class ParamInvalidException(web.HTTPBadRequest):
    """HTTP 400 raised by ``get_param`` when parameters fail validation.

    The response body is JSON of the form
    ``{"status": -1, "error": {"code": "param_invalid", "message": ...}}``.

    Attributes:
        code: machine-readable error code, always ``"param_invalid"``.
        param_list: names of the parameters that failed validation.
        rules: the full rule set the parameters were checked against.
    """

    def __init__(self, param_list: list[str], rules: dict[str, ParamRule]):
        self.code = "param_invalid"
        self.param_list = param_list
        self.rules = rules
        param_list_str = "'" + "', '".join(param_list) + "'"
        message = f"Param invalid: {param_list_str}"
        # aiohttp's HTTPException.__init__ accepts keyword arguments only; the
        # message used to be passed positionally, which raised TypeError and
        # turned every validation failure into a 500. ``text=`` is used rather
        # than the deprecated ``body=`` for HTTP exceptions.
        super().__init__(
            content_type="application/json",
            text=json.dumps({
                "status": -1,
                "error": {
                    "code": self.code,
                    "message": message,
                },
            }),
        )
def _coerce_bool(val: Any) -> bool:
    """Interpret a request value as a boolean.

    Real booleans pass through. Strings equal to "false" or "0" (after
    lower-casing) and the integer 0 are False. Every other value is True.
    """
    if isinstance(val, bool):
        return val
    if isinstance(val, str):
        return val.lower() not in ("false", "0")
    if isinstance(val, int):
        return val != 0
    return True


async def _collect_params(request: web.Request) -> dict[str, Any]:
    """Merge route, query and (POST only) body parameters; later sources win."""
    params: dict[str, Any] = {}
    params.update(request.match_info.items())
    # .items() on a multidict yields every pair, so the last duplicate wins.
    params.update(request.query.items())
    if request.method == 'POST':
        # ``content_type`` is the parsed media type without parameters, so
        # "application/json; charset=utf-8" is still recognised as JSON.
        if request.content_type == 'application/json':
            data = await request.json()
            if isinstance(data, dict):
                params.update(data)
        else:
            form = await request.post()
            params.update(form.items())
    return params


def _apply_rules(params: dict[str, Any], rules: dict[str, ParamRule]) -> list[str]:
    """Validate and coerce ``params`` in place against ``rules``.

    For each rule key:
    * absent and ``required`` -> reported invalid;
    * absent otherwise -> set to ``rule["default"]`` (``None`` if no default);
    * present with a dict ``type`` -> must be one of the dict's keys;
    * present with ``type`` int / float / bool -> converted, invalid on failure.

    Returns the names of the invalid parameters, in rule order.
    """
    invalid: list[str] = []
    for key, rule in rules.items():
        if key not in params:
            if rule.get("required"):
                invalid.append(key)
            else:
                params[key] = rule.get("default")
            continue
        if "type" not in rule:
            continue
        expected = rule["type"]
        value = params[key]
        if isinstance(expected, dict):
            # A dict "type" is a whitelist of allowed values (its keys). The
            # old ``is dict`` test made ``value in dict`` raise TypeError.
            try:
                allowed = value in expected
            except TypeError:  # unhashable value, e.g. a JSON list/object
                allowed = False
            if not allowed:
                invalid.append(key)
            continue
        try:
            if expected is int:
                params[key] = int(value)
            elif expected is float:
                params[key] = float(value)
            elif expected is bool:
                params[key] = _coerce_bool(value)
        # TypeError: JSON null/list/object; OverflowError: int(float("inf")).
        except (TypeError, ValueError, OverflowError):
            invalid.append(key)
    return invalid


async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None):
    """Collect request parameters and optionally validate them.

    Parameters are merged from the route match info, the query string and,
    for POST requests only, the JSON object body or form body. Later sources
    override earlier ones.

    Args:
        request: the incoming aiohttp request.
        rules: optional per-parameter rules (keys "required", "default",
            "type"). When given, absent optional parameters are filled in and
            typed values are converted in the returned dict.

    Returns:
        The merged parameter dict.

    Raises:
        ParamInvalidException: if any parameter fails its rule.
    """
    params = await _collect_params(request)
    if rules is not None:
        invalid_params = _apply_rules(params, rules)
        if invalid_params:
            raise ParamInvalidException(invalid_params, rules)
    return params
async def api_response(status, data=None, error=None, warning=None, http_status=200, request: Optional[web.Request] = None):
    """Build the standard API reply ``{"status": ..., "data"/"error"/"warning"}``.

    If ``request`` is a WebSocket upgrade request (see ``is_websocket``), the
    payload is tagged with ``"event": "response"``. It is then sent as a single
    JSON message over a newly prepared WebSocket, which is closed afterwards.
    ``http_status`` is not used in that case. Otherwise a JSON HTTP response
    with status ``http_status`` is built.

    Returns:
        The response object an aiohttp handler should return: the
        ``WebSocketResponse`` in the WebSocket case, else the JSON response.
    """
    ret = {"status": status}
    # Sections are omitted only when not supplied, so legitimate falsy
    # payloads such as [] or 0 still reach the client.
    if data is not None:
        ret["data"] = data
    if error is not None:
        ret["error"] = error
    if warning is not None:
        ret["warning"] = warning
    # ``is not None`` rather than truthiness: web.Request is a MutableMapping,
    # so a request with no stored state is falsy.
    if request is not None and is_websocket(request):
        ret["event"] = "response"
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        await ws.send_json(ret)
        await ws.close()
        # aiohttp requires handlers to return the response object; returning
        # None here made the server fail with "Missing return statement".
        return ws
    return web.json_response(ret, status=http_status)
def is_websocket(request: web.Request):
    """Tell whether ``request`` asks for a WebSocket upgrade.

    Looks at the ``Upgrade`` header and compares its lower-cased value with
    "websocket". A missing header counts as not a WebSocket request.
    """
    upgrade_header = request.headers.get('Upgrade', '')
    return 'websocket' == upgrade_header.lower()
def generate_uuid():
    """Return a fresh random (version 4) UUID as a 36-character hyphenated string."""
    fresh = uuid.uuid4()
    return f"{fresh}"
# Auth decorator
def _auth_error(request, code: str, message: str, target: Optional[str] = None):
    """Return the awaitable 401 reply for an authentication failure.

    The error object is built with keys in the order code, [target], message,
    then handed to ``api_response``.
    """
    error: dict[str, str] = {"code": code}
    if target is not None:
        error["target"] = target
    error["message"] = message
    return api_response(status=-1, error=error, http_status=401, request=request)


def _is_server_key(token: str, auth_tokens: dict) -> bool:
    """Return True if ``token`` equals one of the configured string values.

    Uses ``hmac.compare_digest`` per candidate, so the comparison time does not
    reveal how many leading characters of a guessed key were right.
    """
    candidate = token.encode("utf-8", "surrogatepass")
    return any(
        isinstance(known, str)
        and hmac.compare_digest(candidate, known.encode("utf-8", "surrogatepass"))
        for known in auth_tokens.values()
    )


def token_auth(f):
    """Decorate an aiohttp handler so it only runs for authenticated callers.

    Token sources:
    * a ``token`` request parameter (read via ``get_param``), always treated
      as a JWT;
    * otherwise the ``Authorization`` header, with an optional ``Bearer``
      prefix. Values starting with ``sk-`` are server keys; anything else is
      treated as a JWT.

    Outcomes:
    * Server key: must match a value of ``Config.get("authorization")``. Sets
      ``request["user"]`` to the ``user_id`` parameter (default 0) and
      ``request["caller"] = "server"``.
    * JWT: the unverified ``kid`` header selects the HMAC secret from the same
      mapping. Sets ``request["user"]`` from the ``sub`` claim and
      ``request["caller"] = "user"``.
    * Failure: returns the ``api_response`` error reply (401, or 500 for
      unexpected decode errors) without calling ``f``.
    """
    @wraps(f)
    async def decorated_function(*args, **kwargs):
        auth_tokens: dict = Config.get("authorization")
        # NOTE(review): assumes the request is the first positional argument
        # (plain function handlers); confirm for method-based handlers.
        request: web.Request = args[0]
        params = await get_param(request)

        token = params.get("token")
        if not token:
            header = request.headers.get('Authorization')
            if not header:
                return await _auth_error(request, "missing-token", "Missing token.")
            # Strip only a leading scheme; auth scheme names are case-insensitive.
            token = header[7:] if header[:7].lower() == "bearer " else header
            if token.startswith("sk-"):
                if not _is_server_key(token, auth_tokens):
                    return await _auth_error(request, "token-invalid", "Token invalid.", target="token_id")
                request["user"] = params.get("user_id", 0)
                request["caller"] = "server"
                return await f(*args, **kwargs)
        # NOTE(review): a token passed as a request parameter is never checked
        # for the "sk-" prefix, so it is always treated as a JWT. Presumably
        # this keeps server keys out of URLs; confirm it is intended.

        # JWT: the "kid" header names which configured secret signed it.
        try:
            key_id = jwt.get_unverified_header(token)["kid"]
            secret = auth_tokens[key_id]
        # KeyError: no kid / unknown kid. TypeError: unhashable kid.
        # InvalidTokenError: malformed token or non-string kid.
        except (KeyError, TypeError, jwt.exceptions.InvalidTokenError):
            return await _auth_error(request, "token-invalid", "Token issuer not exists.", target="token_id")

        try:
            data = jwt.decode(token, secret, algorithms=['HS256', 'HS384', 'HS512'])
        except (jwt.exceptions.DecodeError, jwt.exceptions.InvalidSignatureError, jwt.exceptions.InvalidAlgorithmError):
            return await _auth_error(request, "token-invalid", "Invalid signature.", target="signature")
        except jwt.exceptions.ExpiredSignatureError:
            return await _auth_error(request, "token-invalid", "Token expired.", target="expire")
        except jwt.exceptions.InvalidTokenError:
            # Remaining claim failures (nbf, iat, aud, ...) are client errors,
            # not server errors.
            return await _auth_error(request, "token-invalid", "Token claims invalid.", target="claims")
        except Exception as e:
            return await api_response(status=-1, error=str(e), http_status=500, request=request)

        if "sub" not in data:
            return await _auth_error(request, "token-invalid", "Token subject invalid.", target="subject")
        request["user"] = data["sub"]
        request["caller"] = "user"
        return await f(*args, **kwargs)

    return decorated_function