From eb6e364fe767f25ec824f4bfd174116929a06b7a Mon Sep 17 00:00:00 2001
From: yangapku
Date: Mon, 25 Sep 2023 20:29:29 +0800
Subject: [PATCH] update cli_demo.py and web_demo.py

---
 cli_demo.py | 35 ++++++++++++++---------------------
 web_demo.py | 51 ++++++++++++++++++++++-----------------------
 2 files changed, 36 insertions(+), 50 deletions(-)

diff --git a/cli_demo.py b/cli_demo.py
index f0ce3c7..51b1b13 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -18,8 +18,12 @@ from transformers.trainer_utils import set_seed
 DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'
 
 _WELCOME_MSG = '''\
-Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
-欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
+Welcome to use Qwen-Chat model, type text to start chat, type :h to show command help.
+(欢迎使用 Qwen-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)
+
+Note: This demo is governed by the original license of Qwen.
+We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
+(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
 '''
 
 _HELP_MSG = '''\
 Commands:
@@ -46,23 +50,12 @@ def _load_model_tokenizer(args):
     else:
         device_map = "auto"
 
-    qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
-    if os.path.exists(qconfig_path):
-        from auto_gptq import AutoGPTQForCausalLM
-        model = AutoGPTQForCausalLM.from_quantized(
-            args.checkpoint_path,
-            device_map=device_map,
-            trust_remote_code=True,
-            resume_download=True,
-            use_safetensors=True,
-        ).eval()
-    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.checkpoint_path,
-            device_map=device_map,
-            trust_remote_code=True,
-            resume_download=True,
-        ).eval()
+    model = AutoModelForCausalLM.from_pretrained(
+        args.checkpoint_path,
+        device_map=device_map,
+        trust_remote_code=True,
+        resume_download=True,
+    ).eval()
 
     config = GenerationConfig.from_pretrained(
         args.checkpoint_path, trust_remote_code=True, resume_download=True,
@@ -103,7 +96,7 @@ def _get_input() -> str:
 
 def main():
     parser = argparse.ArgumentParser(
-        description='QWen-7B-Chat command-line interactive chat demo.')
+        description='QWen-Chat command-line interactive chat demo.')
     parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                         help="Checkpoint name or path, default to %(default)r")
     parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
@@ -195,7 +188,7 @@ def main():
                 for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
                     _clear_screen()
                     print(f"\nUser: {query}")
-                    print(f"\nQwen-7B: {response}")
+                    print(f"\nQwen-Chat: {response}")
             except KeyboardInterrupt:
                 print('[WARNING] Generation interrupted')
                 continue

diff --git a/web_demo.py b/web_demo.py
index 11d7a56..8332433 100755
--- a/web_demo.py
+++ b/web_demo.py
@@ -47,23 +47,12 @@ def _load_model_tokenizer(args):
     else:
         device_map = "auto"
 
-    qconfig_path = os.path.join(args.checkpoint_path, 'quantize_config.json')
-    if os.path.exists(qconfig_path):
-        from auto_gptq import AutoGPTQForCausalLM
-        model = AutoGPTQForCausalLM.from_quantized(
-            args.checkpoint_path,
-            device_map=device_map,
-            trust_remote_code=True,
-            resume_download=True,
-            use_safetensors=True,
-        ).eval()
-    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.checkpoint_path,
-            device_map=device_map,
-            trust_remote_code=True,
-            resume_download=True,
-        ).eval()
+    model = AutoModelForCausalLM.from_pretrained(
+        args.checkpoint_path,
+        device_map=device_map,
+        trust_remote_code=True,
+        resume_download=True,
+    ).eval()
 
     config = GenerationConfig.from_pretrained(
         args.checkpoint_path, trust_remote_code=True, resume_download=True,
@@ -133,7 +122,7 @@ def _launch_demo(args, model, tokenizer, config):
 
         print(f"History: {_task_history}")
         _task_history.append((_query, full_response))
-        print(f"Qwen-7B-Chat: {_parse_text(full_response)}")
+        print(f"Qwen-Chat: {_parse_text(full_response)}")
 
     def regenerate(_chatbot, _task_history):
         if not _task_history:
@@ -156,21 +145,25 @@ def _launch_demo(args, model, tokenizer, config):
 
     with gr.Blocks() as demo:
         gr.Markdown("""\
-<p align="center"><img src="..." style="height: 80px"/><p>""")
-        gr.Markdown("""<center><font size=8>Qwen-7B-Chat Bot</center>""")
+<p align="center"><img src="..." style="height: 80px"/><p>""")
+        gr.Markdown("""<center><font size=8>Qwen-Chat Bot</center>""")
         gr.Markdown(
             """\
-<center><font size=3>This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. \
-(本WebUI基于Qwen-7B-Chat打造,实现聊天机器人功能。)</center>""")
+<center><font size=3>This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \
+(本WebUI基于Qwen-Chat打造,实现聊天机器人功能。)</center>""")
         gr.Markdown("""\
-<center><font size=4>Qwen-7B <a href="...">🤖</a> |
-<a href="...">🤗</a>&nbsp |
-Qwen-7B-Chat <a href="...">🤖</a> |
-<a href="...">🤗</a>&nbsp |
-&nbsp<a href="...">Github</a></center>""")
+<center><font size=4>
+Qwen-7B <a href="...">🤖</a> |
+<a href="...">🤗</a>&nbsp |
+Qwen-7B-Chat <a href="...">🤖</a> |
+<a href="...">🤗</a>&nbsp |
+Qwen-14B <a href="...">🤖</a> |
+<a href="...">🤗</a>&nbsp |
+Qwen-14B-Chat <a href="...">🤖</a> |
+<a href="...">🤗</a>&nbsp |
+&nbsp<a href="...">Github</a></center>""")
 
-        chatbot = gr.Chatbot(label='Qwen-7B-Chat', elem_classes="control-height")
+        chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height")
         query = gr.Textbox(lines=2, label='Input')
         task_history = gr.State([])
 
@@ -185,10 +178,10 @@ Qwen-7B-Chat
         regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
 
         gr.Markdown("""\
-Note: This demo is governed by the original license of Qwen-7B. \
+Note: This demo is governed by the original license of Qwen. \
 We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
 including hate speech, violence, pornography, deception, etc. \
-(注:本演示受Qwen-7B的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
+(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
 
     demo.queue().launch(