diff --git a/openai_api.py b/openai_api.py
index fafd8f0..e9674bc 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -19,8 +19,50 @@ from pydantic import BaseModel, Field
 from sse_starlette.sse import EventSourceResponse
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers.generation import GenerationConfig
-
-
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+import base64
+import secrets
+
+
+class BasicAuthMiddleware(BaseHTTPMiddleware):
+    """Reject requests that lack the configured HTTP Basic credentials.
+
+    A request is forwarded to the wrapped app only when its ``Authorization``
+    header uses the ``Basic`` scheme and carries the expected base64
+    ``username:password`` token; every other request is answered with
+    ``401`` and a ``WWW-Authenticate: Basic`` challenge.
+    """
+
+    def __init__(self, app, username: str, password: str):
+        super().__init__(app)
+        # Encode once up front so each request only needs a comparison.
+        self.required_credentials = base64.b64encode(
+            f"{username}:{password}".encode()
+        ).decode()
+
+    def _is_authorized(self, authorization) -> bool:
+        """Return True when *authorization* holds the expected Basic token."""
+        if not authorization:
+            return False
+        scheme, _, credentials = authorization.partition(" ")
+        # Scheme names are case-insensitive; anything but Basic is refused.
+        if scheme.lower() != "basic":
+            return False
+        # Constant-time comparison on bytes: avoids leaking, through timing,
+        # how much of a guessed token matched.
+        return secrets.compare_digest(
+            credentials.strip().encode(), self.required_credentials.encode()
+        )
+
+    async def dispatch(self, request: Request, call_next):
+        """Forward authorized requests; answer the rest with a 401 challenge."""
+        if self._is_authorized(request.headers.get("Authorization")):
+            return await call_next(request)
+        return Response(status_code=401, headers={"WWW-Authenticate": "Basic"})
+
+
 def _gc(forced: bool = False):
     global args
     if args.disable_gc and not forced:
@@ -482,6 +524,11 @@ def _get_args():
         default="Qwen/Qwen-7B-Chat",
         help="Checkpoint name or path, default to %(default)r",
     )
+    parser.add_argument(
+        "--api-auth",
+        metavar="USERNAME:PASSWORD",
+        help="API authentication credentials",
+    )
     parser.add_argument(
         "--cpu-only", action="store_true", help="Run demo with CPU only"
     )
@@ -511,6 +558,15 @@ if __name__ == "__main__":
         resume_download=True,
     )
 
+    if args.api_auth:
+        # Split on the first colon only: a password may itself contain ":".
+        username, sep, password = args.api_auth.partition(":")
+        if not sep:
+            raise SystemExit("--api-auth must have the form USERNAME:PASSWORD")
+        app.add_middleware(
+            BasicAuthMiddleware, username=username, password=password
+        )
+
     if args.cpu_only:
         device_map = "cpu"
     else: