diff --git a/run_web_demo.sh b/run_web_demo.sh
index f3f9a8b..17529f3 100755
--- a/run_web_demo.sh
+++ b/run_web_demo.sh
@@ -3,43 +3,73 @@
 cd "$(dirname "$0")"
 thisDir=$(pwd)
 
+export INSTALL_DEPS=false
+export INSTALL_FLASH_ATTN=false
+
+declare -a PASS_THROUGH_ARGS=()
+
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+    -h | --help)
+        echo "Usage: $0 [-h|--help] [--install-deps] [--install-flash-attn]"
+        exit 0
+        ;;
+    --install-deps)
+        export INSTALL_DEPS=true
+        shift
+        ;;
+    --install-flash-attn)
+        export INSTALL_FLASH_ATTN=true
+        shift
+        ;;
+    --)
+        shift
+        PASS_THROUGH_ARGS=("$@")
+        break
+        ;;
+
+    *)
+        echo "Unknown option: $1"
+        exit 1
+        ;;
+    esac
+done
+
+echo "INSTALL_DEPS: $INSTALL_DEPS"
+echo "INSTALL_FLASH_ATTN: $INSTALL_FLASH_ATTN"
+echo "PASS_THROUGH_ARGS: ${PASS_THROUGH_ARGS[@]}"
+
 function performInstall() {
-    set -e
     pushd "$thisDir"
 
     pip3 install -r requirements.txt
-    pip3 install gradio mdtex2html scipy
+    pip3 install gradio mdtex2html scipy
 
-    if [[ ! -d flash-attention ]]; then
-        if ! git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention; then
-            echo "Clone flash-attention failed, please install it manually."
-            return 0
+    if $INSTALL_FLASH_ATTN; then
+        if [[ ! -d flash-attention ]]; then
+            if ! git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention; then
+                echo "Clone flash-attention failed, please install it manually."
+                return 0
+            fi
         fi
+
+        cd flash-attention &&
+            pip3 install . &&
+            pip3 install csrc/layer_norm &&
+            pip3 install csrc/rotary ||
+            echo "Install flash-attention failed, please install it manually."
     fi
-    cd flash-attention &&
-        pip3 install . &&
-        pip3 install csrc/layer_norm &&
-        pip3 install csrc/rotary ||
-        echo "Install flash-attention failed, please install it manually."
 
     popd
 }
 
 echo "Starting WebUI..."
-if ! python3 web_demo.py; then
-    echo "Run demo failed, install the deps and try again? (y/n)"
-    # auto perform install if in docker
-    if [[ -t 0 ]] && [[ -t 1 ]] && [[ ! -f "/.dockerenv" ]]; then
-        read doInstall
+if ! python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"; then
+    if $INSTALL_DEPS; then
+        echo "Installing deps, and try again..."
+        performInstall && python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"
     else
-        doInstall="y"
+        echo "Please install deps manually, or use --install-deps to install deps automatically."
     fi
-
-    if ! [[ "$doInstall" =~ y|Y ]]; then
-        exit 1
-    fi
-
-    echo "Installing deps, and try again..."
-    performInstall && python3 web_demo.py
 fi
diff --git a/web_demo.py b/web_demo.py
index 9f812cf..27a12a7 100755
--- a/web_demo.py
+++ b/web_demo.py
@@ -7,21 +7,14 @@ import gradio as gr
 import mdtex2html
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
+from argparse import ArgumentParser
 import sys
 
-tokenizer = AutoTokenizer.from_pretrained(
-    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
-)
-model = AutoModelForCausalLM.from_pretrained(
-    "Qwen/Qwen-7B-Chat",
-    device_map="auto",
-    offload_folder="offload",
-    trust_remote_code=True,
-    resume_download=True,
-).eval()
-model.generation_config = GenerationConfig.from_pretrained(
-    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
-)
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True)
+
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", offload_folder="offload", trust_remote_code=True, resume_download=True).eval()
+
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True)
 
 if len(sys.argv) > 1 and sys.argv[1] == "--exit":
     sys.exit(0)
@@ -82,7 +75,7 @@ def predict(input, chatbot):
     chatbot.append((parse_text(input), ""))
     fullResponse = ""
 
-    for response in model.chat(tokenizer, input, history=task_history, stream=True):
+    for response in model.chat_stream(tokenizer, input, history=task_history):
         chatbot[-1] = (parse_text(input), parse_text(response))
         yield chatbot
 
@@ -108,9 +101,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=4):
             with gr.Column(scale=12):
-                query = gr.Textbox(
-                    show_label=False, placeholder="Input...", lines=10
-                ).style(container=False)
+                query = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
             with gr.Column(min_width=32, scale=1):
                 submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
@@ -120,4 +111,17 @@ with gr.Blocks() as demo:
     submitBtn.click(reset_user_input, [], [query])
     emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
 
-demo.queue().launch(share=False, inbrowser=True, server_port=80, server_name="0.0.0.0")
+if len(sys.argv) > 1:
+
+    print("Call args:" + str(sys.argv))
+    parser = ArgumentParser()
+    parser.add_argument("--share", action="store_true", default=False)
+    parser.add_argument("--inbrowser", action="store_true", default=False)
+    parser.add_argument("--server_port", type=int, default=80)
+    parser.add_argument("--server_name", type=str, default="0.0.0.0")
+    args = parser.parse_args(sys.argv[1:])
+    print("Args:" + str(args))
+
+    demo.queue().launch(share=args.share, inbrowser=args.inbrowser, server_port=args.server_port, server_name=args.server_name)
+else:
+    demo.queue().launch()