|
|
|
@ -19,8 +19,29 @@ from pydantic import BaseModel, Field
|
|
|
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
from transformers.generation import GenerationConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
from starlette.requests import Request
|
|
|
|
|
from starlette.responses import Response
|
|
|
|
|
import base64
|
|
|
|
|
|
|
|
|
|
class BasicAuthMiddleware(BaseHTTPMiddleware):
    """Middleware that requires HTTP Basic credentials on every request.

    The expected ``username:password`` pair is base64-encoded once at
    construction time. Each request's ``Authorization`` header must be of
    the form ``Basic <base64>`` and carry exactly that encoded value;
    otherwise a ``401`` response with a ``WWW-Authenticate: Basic``
    challenge is returned and the wrapped application is not called.
    """

    def __init__(self, app, username: str, password: str):
        """Wrap *app* and precompute the expected credentials token.

        Args:
            app: The ASGI application to protect.
            username: Expected user name.
            password: Expected password.
        """
        super().__init__(app)
        self.required_credentials = base64.b64encode(
            f"{username}:{password}".encode()
        ).decode()

    async def dispatch(self, request: Request, call_next):
        """Forward the request if it carries the expected credentials.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request on to the wrapped
                application and returns its response.

        Returns:
            The downstream response when authentication succeeds, or an
            empty ``401`` response carrying a ``WWW-Authenticate: Basic``
            header when it does not.
        """
        # Imported here so the block does not depend on a new file-level import.
        from hmac import compare_digest

        if self._is_authorized(request.headers.get("Authorization"), compare_digest):
            # Deliberately outside any try/except: an exception raised by the
            # downstream application must not be mistaken for an auth failure.
            return await call_next(request)

        headers = {'WWW-Authenticate': 'Basic'}
        return Response(status_code=401, headers=headers)

    def _is_authorized(self, authorization, compare_digest) -> bool:
        """Return True when *authorization* is a matching ``Basic`` header."""
        if not authorization:
            return False

        parts = authorization.split()
        if len(parts) != 2:
            # Malformed header: must be exactly "<scheme> <credentials>".
            return False

        scheme, credentials = parts
        # Auth scheme names are case-insensitive; anything but Basic is rejected.
        if scheme.lower() != "basic":
            return False

        # Constant-time comparison avoids leaking how many leading characters
        # matched. Compare as bytes because compare_digest rejects non-ASCII str.
        return compare_digest(
            credentials.encode("utf-8"),
            self.required_credentials.encode("utf-8"),
        )
|
|
|
|
|
|
|
|
|
|
def _gc(forced: bool = False):
|
|
|
|
|
global args
|
|
|
|
|
if args.disable_gc and not forced:
|
|
|
|
@ -475,6 +496,9 @@ def _get_args():
|
|
|
|
|
default="Qwen/Qwen-7B-Chat",
|
|
|
|
|
help="Checkpoint name or path, default to %(default)r",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--api-auth", help="API authentication credentials"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--cpu-only", action="store_true", help="Run demo with CPU only"
|
|
|
|
|
)
|
|
|
|
@ -504,6 +528,11 @@ if __name__ == "__main__":
|
|
|
|
|
resume_download=True,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if args.api_auth:
    # Split on the FIRST colon only: the user name cannot contain ':' but the
    # password may, so everything after the first colon is the password.
    auth_username, auth_separator, auth_password = args.api_auth.partition(":")
    if not auth_separator:
        # Fail with a readable message instead of an IndexError traceback.
        raise SystemExit("--api-auth must be given in the form 'username:password'")
    app.add_middleware(
        BasicAuthMiddleware, username=auth_username, password=auth_password
    )
|
|
|
|
|
|
|
|
|
|
if args.cpu_only:
|
|
|
|
|
device_map = "cpu"
|
|
|
|
|
else:
|
|
|
|
|