Merge pull request #606 from joindn/api-auth

add api-auth for openai_api.py
main
Jianxin Ma 1 year ago committed by GitHub
commit 54bca4fcd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,8 +19,29 @@ from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig from transformers.generation import GenerationConfig
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import base64
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
async def dispatch(self, request: Request, call_next):
authorization: str = request.headers.get("Authorization")
if authorization:
try:
schema, credentials = authorization.split()
if credentials == self.required_credentials:
return await call_next(request)
except ValueError:
pass
headers = {'WWW-Authenticate': 'Basic'}
return Response(status_code=401, headers=headers)
def _gc(forced: bool = False): def _gc(forced: bool = False):
global args global args
if args.disable_gc and not forced: if args.disable_gc and not forced:
@ -482,6 +503,9 @@ def _get_args():
default="Qwen/Qwen-7B-Chat", default="Qwen/Qwen-7B-Chat",
help="Checkpoint name or path, default to %(default)r", help="Checkpoint name or path, default to %(default)r",
) )
parser.add_argument(
"--api-auth", help="API authentication credentials"
)
parser.add_argument( parser.add_argument(
"--cpu-only", action="store_true", help="Run demo with CPU only" "--cpu-only", action="store_true", help="Run demo with CPU only"
) )
@ -511,6 +535,11 @@ if __name__ == "__main__":
resume_download=True, resume_download=True,
) )
if args.api_auth:
app.add_middleware(
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
)
if args.cpu_only: if args.cpu_only:
device_map = "cpu" device_map = "cpu"
else: else:

Loading…
Cancel
Save