add api-auth

main
曾金栋 2 years ago
parent 99cacff46a
commit ee4b20f2fa

@ -19,7 +19,28 @@ from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig from transformers.generation import GenerationConfig
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import base64
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
async def dispatch(self, request: Request, call_next):
authorization: str = request.headers.get("Authorization")
if authorization:
try:
schema, credentials = authorization.split()
if credentials == self.required_credentials:
return await call_next(request)
except ValueError:
pass
headers = {'WWW-Authenticate': 'Basic'}
return Response(status_code=401, headers=headers)
def _gc(forced: bool = False): def _gc(forced: bool = False):
global args global args
@ -475,6 +496,9 @@ def _get_args():
default="Qwen/Qwen-7B-Chat", default="Qwen/Qwen-7B-Chat",
help="Checkpoint name or path, default to %(default)r", help="Checkpoint name or path, default to %(default)r",
) )
parser.add_argument(
"--api-auth", help="API authentication credentials"
)
parser.add_argument( parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only" "--cpu-only", action="store_true", help="Run demo with CPU only"
) )
@ -504,6 +528,11 @@ if __name__ == "__main__":
resume_download=True, resume_download=True,
) )
if args.api_auth:
app.add_middleware(
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
)
if args.cpu_only: if args.cpu_only:
device_map = "cpu" device_map = "cpu"
else: else:

Loading…
Cancel
Save