add api-auth

main
曾金栋 1 year ago
parent 99cacff46a
commit ee4b20f2fa

@ -19,8 +19,29 @@ from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import base64
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
async def dispatch(self, request: Request, call_next):
authorization: str = request.headers.get("Authorization")
if authorization:
try:
schema, credentials = authorization.split()
if credentials == self.required_credentials:
return await call_next(request)
except ValueError:
pass
headers = {'WWW-Authenticate': 'Basic'}
return Response(status_code=401, headers=headers)
def _gc(forced: bool = False):
global args
if args.disable_gc and not forced:
@ -475,6 +496,9 @@ def _get_args():
default="Qwen/Qwen-7B-Chat",
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--api-auth", help="API authentication credentials"
)
parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only"
)
@ -504,6 +528,11 @@ if __name__ == "__main__":
resume_download=True,
)
if args.api_auth:
app.add_middleware(
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
)
if args.cpu_only:
device_map = "cpu"
else:

Loading…
Cancel
Save