From 2e5cff645445e5b8e6060ad8f77d5869ba73d55e Mon Sep 17 00:00:00 2001
From: yangapku
Date: Thu, 10 Aug 2023 21:47:23 +0800
Subject: [PATCH] update web_demo

---
 README.md    |   2 +-
 README_CN.md |   2 +-
 README_JA.md |   2 +-
 cli_demo.py  |   8 +-
 web_demo.py  | 223 +++++++++++++++++++++++++++++----------------------
 5 files changed, 133 insertions(+), 104 deletions(-)

diff --git a/README.md b/README.md
index dec6cf4..98bab0a 100644
--- a/README.md
+++ b/README.md
@@ -257,7 +257,7 @@ python cli_demo.py
 We provide code for users to build a web UI demo (thanks to @wysiad). Before you start, make sure you install the following packages:
 
 ```
-pip install gradio mdtex2html
+pip install -r requirements_web_demo.txt
 ```
 
 Then run the command below and click on the generated link:
diff --git a/README_CN.md b/README_CN.md
index 78077a7..a034c90 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -259,7 +259,7 @@ python cli_demo.py
 我们提供了Web UI的demo供用户使用 (感谢 @wysiad 支持)。在开始前,确保已经安装如下代码库:
 
 ```
-pip install gradio mdtex2html
+pip install -r requirements_web_demo.txt
 ```
 
 随后运行如下命令,并点击生成链接:
diff --git a/README_JA.md b/README_JA.md
index bc51ab3..069d6da 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -264,7 +264,7 @@ python cli_demo.py
 ウェブUIデモを構築するためのコードを提供します(@wysiadに感謝)。始める前に、以下のパッケージがインストールされていることを確認してください:
 
 ```
-pip install gradio mdtex2html
+pip install -r requirements_web_demo.txt
 ```
 
 そして、以下のコマンドを実行し、生成されたリンクをクリックする:
diff --git a/cli_demo.py b/cli_demo.py
index 52a3a63..4a095ad 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -38,7 +38,7 @@ Commands:
 
 def _load_model_tokenizer(args):
     tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True,
+        args.checkpoint_path, trust_remote_code=True, resume_download=True,
     )
 
     if args.cpu_only:
@@ -50,9 +50,11 @@ def _load_model_tokenizer(args):
         args.checkpoint_path,
         device_map=device_map,
         trust_remote_code=True,
+        resume_download=True,
     ).eval()
-    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
-
+    model.generation_config = GenerationConfig.from_pretrained(
+        args.checkpoint_path, trust_remote_code=True, resume_download=True,
+    )
     return model, tokenizer
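All three READMEs now point at a requirements file instead of naming packages inline; judging from the command it replaces, `requirements_web_demo.txt` presumably pins at least `gradio` and `mdtex2html`. In `cli_demo.py`, `resume_download=True` is threaded through every `from_pretrained` call; `transformers` forwards it to the Hugging Face Hub downloader, so a multi-gigabyte checkpoint download that drops mid-way continues from the partial file instead of restarting. A minimal sketch of the same loading pattern in isolation (checkpoint id and flags taken from the patch; `device_map="auto"` assumes `accelerate` is installed):

```python
# Standalone sketch of the loader used by cli_demo.py and web_demo.py.
# resume_download=True lets an interrupted checkpoint download pick up
# where it left off rather than starting from scratch.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

checkpoint = "Qwen/Qwen-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(
    checkpoint, trust_remote_code=True, resume_download=True,
)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",       # "cpu" when running with --cpu-only
    trust_remote_code=True,  # Qwen ships custom modeling code on the Hub
    resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    checkpoint, trust_remote_code=True, resume_download=True,
)
```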
+ +"""A simple web interactive chat demo based on gradio.""" + +from argparse import ArgumentParser -from transformers import AutoTokenizer import gradio as gr import mdtex2html from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig -from argparse import ArgumentParser -import sys - -print("Call args:" + str(sys.argv)) -parser = ArgumentParser() -parser.add_argument("--share", action="store_true", default=False) -parser.add_argument("--inbrowser", action="store_true", default=False) -parser.add_argument("--server_port", type=int, default=80) -parser.add_argument("--server_name", type=str, default="0.0.0.0") -parser.add_argument("--exit", action="store_true", default=False) -parser.add_argument("--model_revision", type=str, default="") -args = parser.parse_args(sys.argv[1:]) -print("Args:" + str(args)) - -tokenizer = AutoTokenizer.from_pretrained( - "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True -) - -model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen-7B-Chat", - device_map="auto", - trust_remote_code=True, - resume_download=True, - **{"revision": args.model_revision} - if args.model_revision is not None - and args.model_revision != "" - and args.model_revision != "None" - else {}, -).eval() - -model.generation_config = GenerationConfig.from_pretrained( - "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True -) - -if "exit" in args: - if args.exit: - sys.exit(0) + +DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat' + + +def _get_args(): + parser = ArgumentParser() + parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + help="Checkpoint name or path, default to %(default)r") + parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") + + parser.add_argument("--share", action="store_true", default=False, + help="Create a publicly shareable link for the interface.") + parser.add_argument("--inbrowser", action="store_true", default=False, + help="Automatically launch the interface in a new tab on the default browser.") + parser.add_argument("--server-port", type=int, default=8000, + help="Demo server port.") + parser.add_argument("--server-name", type=str, default="127.0.0.1", + help="Demo server name.") + + args = parser.parse_args() + return args + + +def _load_model_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + if args.cpu_only: + device_map = "cpu" else: - del args.exit + device_map = "auto" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + device_map=device_map, + trust_remote_code=True, + resume_download=True, + ).eval() + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) -if "model_revision" in args: - del args.model_revision + return model, tokenizer def postprocess(self, y): @@ -54,7 +62,7 @@ def postprocess(self, y): return [] for i, (message, response) in enumerate(y): y[i] = ( - None if message is None else mdtex2html.convert((message)), + None if message is None else mdtex2html.convert(message), None if response is None else mdtex2html.convert(response), ) return y @@ -63,7 +71,7 @@ def postprocess(self, y): gr.Chatbot.postprocess = postprocess -def parse_text(text): +def _parse_text(text): lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 @@ -78,7 +86,7 @@ def parse_text(text): else: if i > 0: if count % 
@@ -54,7 +62,7 @@ def postprocess(self, y):
         return []
     for i, (message, response) in enumerate(y):
         y[i] = (
-            None if message is None else mdtex2html.convert((message)),
+            None if message is None else mdtex2html.convert(message),
             None if response is None else mdtex2html.convert(response),
         )
     return y
@@ -63,7 +71,7 @@ def postprocess(self, y):
 gr.Chatbot.postprocess = postprocess
 
 
-def parse_text(text):
+def _parse_text(text):
     lines = text.split("\n")
     lines = [line for line in lines if line != ""]
     count = 0
@@ -78,7 +86,7 @@ def parse_text(text):
         else:
             if i > 0:
                 if count % 2 == 1:
-                    line = line.replace("`", "\`")
+                    line = line.replace("`", r"\`")
                     line = line.replace("<", "&lt;")
                     line = line.replace(">", "&gt;")
                     line = line.replace(" ", "&nbsp;")
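Both hunks above are behavior-preserving cleanups: the doubled parentheses around `message` go away, and the backtick replacement becomes a raw string. `\`` is not a recognized escape sequence, so `"\`"` already denoted a backslash plus a backtick; `r"\`"` spells the same two characters explicitly and avoids the invalid-escape warning newer Pythons emit. A two-line check:

```python
# An unrecognized escape keeps its backslash, so the two spellings agree;
# the raw string is just explicit about it.
assert "\`" == r"\`" == "\\`"
assert len(r"\`") == 2  # backslash + backtick
```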
@@ -95,70 +103,89 @@ def parse_text(text):
     return text
 
 
-task_history = []
-
-
-def predict(query, chatbot):
-    print("User: " + parse_text(query))
-    chatbot.append((parse_text(query), ""))
-    fullResponse = ""
-
-    for response in model.chat_stream(tokenizer, query, history=task_history):
-        chatbot[-1] = (parse_text(query), parse_text(response))
-
-        yield chatbot
-        fullResponse = parse_text(response)
+def _launch_demo(args, model, tokenizer):
+    task_history = []
 
-    task_history.append((query, fullResponse))
-    print("Qwen-7B-Chat: " + parse_text(fullResponse))
+    def predict(_query, _chatbot):
+        print("User: " + _parse_text(_query))
+        _chatbot.append((_parse_text(_query), ""))
+        full_response = ""
 
+        for response in model.chat_stream(tokenizer, _query, history=task_history):
+            _chatbot[-1] = (_parse_text(_query), _parse_text(response))
 
-def regenerate(chatbot):
-    if not task_history:
-        yield chatbot
-        return
-    item = task_history.pop(-1)
-    chatbot.pop(-1)
-    yield from predict(item[0], chatbot)
+            yield _chatbot
+            full_response = _parse_text(response)
 
+        task_history.append((_query, full_response))
+        print("Qwen-7B-Chat: " + _parse_text(full_response))
 
-def reset_user_input():
-    return gr.update(value="")
+    def regenerate(_chatbot):
+        if not task_history:
+            yield _chatbot
+            return
+        item = task_history.pop(-1)
+        _chatbot.pop(-1)
+        yield from predict(item[0], _chatbot)
 
+    def reset_user_input():
+        return gr.update(value="")
 
-def reset_state():
-    task_history.clear()
-    return []
-
+    def reset_state():
+        task_history.clear()
+        return []
 
-with gr.Blocks() as demo:
-    gr.Markdown("""
-
-""")
-    gr.Markdown("""
-
-Qwen-7B-Chat Bot
-""")
-    gr.Markdown(
-        """
-This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. (本WebUI基于Qwen-7B-Chat打造,实现聊天机器人功能。)
-"""
-    )
-    gr.Markdown(
-        """
-Qwen-7B 🤖 | 🤗  | Qwen-7B-Chat 🤖 | 🤗  |  Github
-"""
+    with gr.Blocks() as demo:
+        gr.Markdown("""\
+
+""")
+        gr.Markdown("""
+
+Qwen-7B-Chat Bot
+""")
+        gr.Markdown(
+            """\
+This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. \
+(本WebUI基于Qwen-7B-Chat打造,实现聊天机器人功能。)
+""")
+        gr.Markdown("""\
+Qwen-7B 🤖
+| 🤗  |
+Qwen-7B-Chat 🤖 |
+🤗  |
+ Github
+""")
+
+        chatbot = gr.Chatbot(label='Qwen-7B-Chat', elem_classes="control-height")
+        query = gr.Textbox(lines=2, label='Input')
+
+        with gr.Row():
+            empty_btn = gr.Button("🧹 Clear History (清除历史)")
+            submit_btn = gr.Button("🚀 Submit (发送)")
+            regen_btn = gr.Button("🤔️ Regenerate (重试)")
+
+        submit_btn.click(predict, [query, chatbot], [chatbot], show_progress=True)
+        submit_btn.click(reset_user_input, [], [query])
+        empty_btn.click(reset_state, outputs=[chatbot], show_progress=True)
+        regen_btn.click(regenerate, [chatbot], [chatbot], show_progress=True)
+
+        gr.Markdown("""\
+Note: This demo is governed by the original license of Qwen-7B. \
+We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
+including hate speech, violence, pornography, deception, etc. \
+(注:本演示受Qwen-7B的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
+包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
+
+    demo.queue().launch(
+        share=args.share,
+        inbrowser=args.inbrowser,
+        server_port=args.server_port,
+        server_name=args.server_name,
+    )
-    )
-    chatbot = gr.Chatbot(lines=10, label='Qwen-7B-Chat', elem_classes="control-height")
-    query = gr.Textbox(lines=2, label='Input')
-
-    with gr.Row():
-        emptyBtn = gr.Button("🧹 Clear History (清除历史)")
-        submitBtn = gr.Button("🚀 Submit (发送)")
-        regenBtn = gr.Button("🤔️ Regenerate (重试)")
-
-    submitBtn.click(predict, [query, chatbot], [chatbot], show_progress=True)
-    submitBtn.click(reset_user_input, [], [query])
-    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
-    regenBtn.click(regenerate, [chatbot], [chatbot], show_progress=True)
+def main():
+    args = _get_args()
 
-    gr.Markdown(
-        """Note: This demo is governed by the original license of Qwen-7B. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. (注:本演示受Qwen-7B的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)"""
-    )
+    model, tokenizer = _load_model_tokenizer(args)
+
+    _launch_demo(args, model, tokenizer)
 
-if len(sys.argv) > 1:
-    demo.queue().launch(**vars(args))
-else:
-    demo.queue().launch()
\ No newline at end of file
+
+if __name__ == '__main__':
+    main()