Improve the install shell script based on issue comments

main
wsl-wy 2 years ago
parent ad66116fe5
commit 92c5c47a4c

@ -3,43 +3,73 @@
# Run from the directory containing this script so relative paths resolve.
cd "$(dirname "$0")" || exit 1
thisDir=$(pwd)

# Feature flags, overridable via the CLI options parsed below.
export INSTALL_DEPS=false
export INSTALL_FLASH_ATTN=false
# Arguments after the "-"/"--" separator are forwarded verbatim to web_demo.py.
declare -a PASS_THROUGH_ARGS=()

while [[ $# -gt 0 ]]; do
    case "$1" in
    -h | --help)
        echo "Usage: $0 [-h|--help] [--install-deps] [--install-flash-attn] [-- ARGS...]"
        exit 0
        ;;
    --install-deps)
        export INSTALL_DEPS=true
        shift
        ;;
    --install-flash-attn)
        export INSTALL_FLASH_ATTN=true
        shift
        ;;
    - | --)
        # Everything after the separator is passed through untouched.
        # Quoted "$@" preserves arguments that contain spaces or glob chars.
        shift
        PASS_THROUGH_ARGS=("$@")
        break
        ;;
    *)
        echo "Unknown option: $1" >&2
        exit 1
        ;;
    esac
done

echo "INSTALL_DEPS: $INSTALL_DEPS"
echo "INSTALL_FLASH_ATTN: $INSTALL_FLASH_ATTN"
echo "PASS_THROUGH_ARGS: ${PASS_THROUGH_ARGS[*]}"
#######################################
# Install Python dependencies; optionally build flash-attention from source.
# Globals:   thisDir (read), INSTALL_FLASH_ATTN (read)
# Outputs:   progress/diagnostic messages to stdout
# Returns:   0 even when flash-attention fails, so the caller can still try
#            launching the demo without it.
#######################################
function performInstall() {
    set -e
    pushd "$thisDir"
    pip3 install -r requirements.txt
    # NOTE: argparse is part of the Python 3 standard library; installing the
    # stale PyPI backport can shadow it, so it is intentionally not listed.
    pip3 install gradio mdtex2html scipy
    if $INSTALL_FLASH_ATTN; then
        if [[ ! -d flash-attention ]]; then
            if ! git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention; then
                echo "Clone flash-attention failed, please install it manually."
                popd # restore the directory stack before the early return
                return 0
            fi
        fi
        # flash-attention is best-effort: report failure but do not abort.
        cd flash-attention &&
            pip3 install . &&
            pip3 install csrc/layer_norm &&
            pip3 install csrc/rotary ||
            echo "Install flash-attention failed, please install it manually."
    fi
    popd
}
echo "Starting WebUI..."
# First attempt; on failure, install deps and retry only when opted in.
# "${PASS_THROUGH_ARGS[@]}" is quoted so forwarded args keep their spacing.
if ! python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"; then
    if $INSTALL_DEPS; then
        echo "Installing deps, and try again..."
        performInstall && python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"
    else
        echo "Please install deps manually, or use --install-deps to install deps automatically." >&2
        exit 1 # propagate failure instead of exiting 0 silently
    fi
fi

import gradio as gr
import mdtex2html
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from argparse import ArgumentParser
import sys

# Load tokenizer and model once at import time.
# resume_download: interrupted Hub downloads can be resumed.
# offload_folder: allows weight offload when GPU memory is tight.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",
    device_map="auto",
    offload_folder="offload",
    trust_remote_code=True,
    resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)

# Smoke-test hook: "--exit" verifies the model loads, then quits immediately.
if len(sys.argv) > 1 and sys.argv[1] == "--exit":
    sys.exit(0)
@ -82,7 +75,7 @@ def predict(input, chatbot):
chatbot.append((parse_text(input), "")) chatbot.append((parse_text(input), ""))
fullResponse = "" fullResponse = ""
for response in model.chat(tokenizer, input, history=task_history, stream=True): for response in model.chat_stream(tokenizer, input, history=task_history):
chatbot[-1] = (parse_text(input), parse_text(response)) chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot yield chatbot
@ -108,9 +101,7 @@ with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
with gr.Column(scale=12): with gr.Column(scale=12):
query = gr.Textbox( query = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
show_label=False, placeholder="Input...", lines=10
).style(container=False)
with gr.Column(min_width=32, scale=1): with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary") submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1): with gr.Column(scale=1):
@ -120,4 +111,18 @@ with gr.Blocks() as demo:
submitBtn.click(reset_user_input, [], [query]) submitBtn.click(reset_user_input, [], [query])
emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True) emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
# Parse optional launch flags forwarded from the shell wrapper and start the
# Gradio app. With no CLI args, fall back to Gradio's defaults.
if len(sys.argv) > 1:
    print("Call args:" + str(sys.argv))
    parser = ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False)
    parser.add_argument("--inbrowser", action="store_true", default=False)
    parser.add_argument("--server_port", type=int, default=80)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    args = parser.parse_args(sys.argv[1:])
    print("Args:" + str(args))
    # BUG FIX: launch(args) passed the Namespace as a positional argument;
    # launch() accepts these settings only as keyword arguments.
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
else:
    demo.queue().launch()

Loading…
Cancel
Save