diff --git a/web_demo.py b/web_demo.py index 3bf485b..2a9c29b 100755 --- a/web_demo.py +++ b/web_demo.py @@ -10,14 +10,37 @@ from transformers.generation import GenerationConfig from argparse import ArgumentParser import sys +print("Call args:" + str(sys.argv)) +parser = ArgumentParser() +parser.add_argument("--share", action="store_true", default=False) +parser.add_argument("--inbrowser", action="store_true", default=False) +parser.add_argument("--server_port", type=int, default=80) +parser.add_argument("--server_name", type=str, default="0.0.0.0") +parser.add_argument("--exit", action="store_true", default=False) +parser.add_argument("--model_revision", type=str, default="") +args = parser.parse_args(sys.argv[1:]) +print("Args:" + str(args)) + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True) -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, resume_download=True).eval() +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen-7B-Chat", + device_map="auto", + trust_remote_code=True, + resume_download=True, + **{"revision": args.model_revision} if args.model_revision is not None and args.model_revision != "" else {}, +).eval() model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True) -if len(sys.argv) > 1 and sys.argv[1] == "--exit": - sys.exit(0) +if 'exit' in args: + if args.exit: + sys.exit(0) + else: + del args.exit + +if 'model_revision' in args: + del args.model_revision def postprocess(self, y): @@ -112,17 +135,6 @@ with gr.Blocks() as demo: emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True) if len(sys.argv) > 1: - - print("Call args:" + str(sys.argv)) - parser = ArgumentParser() - parser.add_argument("--share", action="store_true", default=False) - parser.add_argument("--inbrowser", action="store_true", default=False) - parser.add_argument("--server_port", type=int, default=80) - parser.add_argument("--server_name", type=str, default="0.0.0.0") - args = parser.parse_args(sys.argv[1:]) - print("Args:" + str(args)) - - print("Args:" + str(args)) - demo.queue().launch(args) + demo.queue().launch(**vars(args)) else: demo.queue().launch()