from __future__ import annotations from functools import wraps import json from typing import Any, Optional, Dict from aiohttp import web import jwt import uuid from config import Config ParamRule = Dict[str, Any] class ParamInvalidException(web.HTTPBadRequest): def __init__(self, param_list: list[str], rules: dict[str, ParamRule]): self.code = "param_invalid" self.param_list = param_list self.rules = rules param_list_str = "'" + ("', '".join(param_list)) + "'" super().__init__(f"Param invalid: {param_list_str}", content_type="application/json", body=json.dumps({ "status": -1, "error": { "code": self.code, "message": f"Param invalid: {param_list_str}" } })) async def get_param(request: web.Request, rules: Optional[dict[str, ParamRule]] = None): params: dict[str, Any] = {} for key, value in request.match_info.items(): params[key] = value for key, value in request.query.items(): params[key] = value if request.method == 'POST': if request.headers.get('content-type') == 'application/json': data = await request.json() if data is not None and isinstance(data, dict): for key, value in data.items(): params[key] = value else: data = await request.post() for key, value in data.items(): params[key] = value if rules is not None: invalid_params: list[str] = [] for key, rule in rules.items(): if "required" in rule and rule["required"] and key not in params.keys(): invalid_params.append(key) continue if key in params: if "type" in rule: if rule["type"] is dict: if params[key] not in rule["type"]: invalid_params.append(key) continue try: if rule["type"] == int: params[key] = int(params[key]) elif rule["type"] == float: params[key] = float(params[key]) elif rule["type"] == bool: val = params[key] if isinstance(val, bool): params[key] = val elif isinstance(val, str): val = val.lower() if val.lower() == "false" or val == "0": params[key] = False else: params[key] = True elif isinstance(val, int): if val == 0: params[key] = False else: params[key] = True else: params[key] = True except ValueError: invalid_params.append(key) continue else: if "default" in rule: params[key] = rule["default"] else: params[key] = None if len(invalid_params) > 0: raise ParamInvalidException(invalid_params, rules) return params async def api_response(status, data=None, error=None, warning=None, http_status=200, request: Optional[web.Request] = None): ret = { "status": status } if data: ret["data"] = data if error: ret["error"] = error if warning: ret["warning"] = warning if request and is_websocket(request): ret["event"] = "response" ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_json(ret) await ws.close() else: return web.json_response(ret, status=http_status) def is_websocket(request: web.Request): return request.headers.get('Upgrade', '').lower() == 'websocket' def generate_uuid(): return str(uuid.uuid4()) # Auth decorator def token_auth(f): @wraps(f) def decorated_function(*args, **kwargs): async def async_wrapper(*args, **kwargs): auth_tokens: dict = Config.get("authorization") request: web.Request = args[0] jwt_token = None sk_token = None params = await get_param(request) token = params.get("token") if token: jwt_token = token else: token: str = request.headers.get('Authorization') or request.headers.get('authorization') if token is None: return await api_response(status=-1, error={ "code": "missing-token", "message": "Missing token." }, http_status=401, request=request) token = token.replace("Bearer ", "") if token.startswith("sk-"): sk_token = token else: jwt_token = token if sk_token is not None: if sk_token not in auth_tokens.values(): return await api_response(status=-1, error={ "code": "token-invalid", "target": "token_id", "message": "Token invalid." }, http_status=401, request=request) if "user_id" in params: request["user"] = params["user_id"] else: request["user"] = 0 request["caller"] = "server" elif jwt_token is not None: # Get appid from jwt header try: jwt_header = jwt.get_unverified_header(jwt_token) key_id: str = jwt_header["kid"] except (KeyError, jwt.exceptions.DecodeError): return await api_response(status=-1, error={ "code": "token-invalid", "target": "token_id", "message": "Token issuer not exists." }, http_status=401, request=request) # Check jwt try: data = jwt.decode(jwt_token, auth_tokens[key_id], algorithms=['HS256', 'HS384', 'HS512']) if "sub" not in data: return await api_response(status=-1, error={ "code": "token-invalid", "target": "subject", "message": "Token subject invalid." }, http_status=401, request=request) request["user"] = data["sub"] request["caller"] = "user" except (jwt.exceptions.DecodeError, jwt.exceptions.InvalidSignatureError, jwt.exceptions.InvalidAlgorithmError): return await api_response(status=-1, error={ "code": "token-invalid", "target": "signature", "message": "Invalid signature." }, http_status=401, request=request) except (jwt.exceptions.ExpiredSignatureError): return await api_response(status=-1, error={ "code": "token-invalid", "target": "expire", "message": "Token expired." }, http_status=401, request=request) except Exception as e: return await api_response(status=-1, error=str(e), http_status=500, request=request) else: return await api_response(status=-1, error={ "code": "missing-token", "message": "Missing token." }, http_status=401, request=request) return await f(*args, **kwargs) return async_wrapper(*args, **kwargs) return decorated_function