Better install shell script, based on issue comment

main
wsl-wy 1 year ago
parent ad66116fe5
commit 92c5c47a4c

@ -3,13 +3,49 @@
# Run from the script's own directory so relative paths resolve.
cd "$(dirname "$0")"
thisDir=$(pwd)

# Flags consumed by this script (exported so performInstall and children see them).
export INSTALL_DEPS=false
export INSTALL_FLASH_ATTN=false
# Arguments after the "-" sentinel are forwarded verbatim to web_demo.py.
declare -a PASS_THROUGH_ARGS=()

while [[ $# -gt 0 ]]; do
  case "$1" in
  -h | --help)
    echo "Usage: $0 [-h|--help] [--install-deps] [--install-flash-attn]"
    exit 0
    ;;
  --install-deps)
    export INSTALL_DEPS=true
    shift
    ;;
  --install-flash-attn)
    export INSTALL_FLASH_ATTN=true
    shift
    ;;
  -)
    # NOTE(review): "-" is used as the pass-through separator; "--" is the
    # conventional choice — confirm this is intended.
    shift
    # Quote "$@" so arguments containing spaces survive as single words
    # (unquoted ($@) would word-split them).
    PASS_THROUGH_ARGS=("$@")
    break
    ;;
  *)
    echo "Unknown option: $1"
    exit 1
    ;;
  esac
done
echo "INSTALL_DEPS: $INSTALL_DEPS"
echo "INSTALL_FLASH_ATTN: $INSTALL_FLASH_ATTN"
# [*] joins the array into one word for display purposes only.
echo "PASS_THROUGH_ARGS: ${PASS_THROUGH_ARGS[*]}"
# Install the demo's python dependencies; optionally clone and build
# flash-attention (v1.0.8) when INSTALL_FLASH_ATTN is true.
# Reads globals: thisDir, INSTALL_FLASH_ATTN. Enables set -e for the duration.
function performInstall() {
set -e
pushd "$thisDir"
pip3 install -r requirements.txt
# NOTE(review): diff residue — both the old and the new pip3 line are present
# below; only one (the argparse variant, per the diff) should remain.
pip3 install gradio mdtex2html scipy
pip3 install gradio mdtex2html scipy argparse
if $INSTALL_FLASH_ATTN; then
if [[ ! -d flash-attention ]]; then
if ! git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention; then
echo "Clone flash-attention failed, please install it manually."
# NOTE(review): lines elided by the diff view here (hunk header below) — the
# clone-success path and enclosing structure are not visible in this capture.
@ -22,24 +58,18 @@ function performInstall() {
# Build the optional flash-attention kernels; a failure is reported but does
# not abort (the || echo swallows the non-zero status despite set -e).
pip3 install csrc/layer_norm &&
pip3 install csrc/rotary ||
echo "Install flash-attention failed, please install it manually."
fi
popd
}
# Launch the web demo, forwarding any collected pass-through args.
# (Reconstructs the post-commit version: the captured diff interleaved the old
# interactive y/n prompt with the new --install-deps flow, leaving unbalanced
# if/fi pairs.)
echo "Starting WebUI..."
if ! python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"; then
  if $INSTALL_DEPS; then
    # --install-deps was given: install requirements, then retry once.
    echo "Installing deps, and try again..."
    performInstall && python3 web_demo.py "${PASS_THROUGH_ARGS[@]}"
  else
    echo "Please install deps manually, or use --install-deps to install deps automatically."
  fi
fi

@ -7,21 +7,14 @@ import gradio as gr
import mdtex2html
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from argparse import ArgumentParser
import sys
# Load the Qwen-7B-Chat tokenizer, model and generation config.
# (The captured diff contained both the old multi-line and new one-line forms
# of these three loads; keeping one copy avoids loading the model twice.)
# resume_download lets interrupted Hub downloads restart; trust_remote_code is
# required because Qwen ships custom modeling code.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)
# device_map="auto" places layers across available devices; offload_folder is
# the spill location when weights do not fit in memory.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",
    device_map="auto",
    offload_folder="offload",
    trust_remote_code=True,
    resume_download=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True, resume_download=True
)
# "--exit" hook: load everything then quit immediately — presumably a
# cache-warming/smoke-test entry point (e.g. during image build); confirm.
if len(sys.argv) > 1 and sys.argv[1] == "--exit":
    sys.exit(0)
@ -82,7 +75,7 @@ def predict(input, chatbot):
chatbot.append((parse_text(input), ""))
fullResponse = ""
for response in model.chat(tokenizer, input, history=task_history, stream=True):
for response in model.chat_stream(tokenizer, input, history=task_history):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot
@ -108,9 +101,7 @@ with gr.Blocks() as demo:
# Gradio layout — interior of `with gr.Blocks() as demo:` per the hunk header.
with gr.Row():
# NOTE(review): diff residue — the old column (scale=4) and the new column
# (scale=12) are both present below; only one should remain.
with gr.Column(scale=4):
with gr.Column(scale=12):
# NOTE(review): diff residue — the old multi-line Textbox and the new
# one-liner are both present; only one should remain.
query = gr.Textbox(
show_label=False, placeholder="Input...", lines=10
).style(container=False)
query = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
# NOTE(review): lines elided by the diff view here (hunk header below) — the
# emptyBtn definition referenced just after is in the elided region.
@ -120,4 +111,18 @@ with gr.Blocks() as demo:
submitBtn.click(reset_user_input, [], [query])
emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)
# Launch the demo. When CLI args are present, parse the supported server
# options and forward them to launch(); otherwise use gradio defaults.
# (Fixes three defects from the captured diff: the old unconditional launch
# line was still present above the conditional, the "Args:" print was
# duplicated, and launch(args) passed the Namespace object to launch()'s
# first positional parameter instead of forwarding the parsed options.)
if len(sys.argv) > 1:
    print("Call args:" + str(sys.argv))
    parser = ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False)
    parser.add_argument("--inbrowser", action="store_true", default=False)
    parser.add_argument("--server_port", type=int, default=80)
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    args = parser.parse_args(sys.argv[1:])
    print("Args:" + str(args))
    # Forward as keyword arguments so each option reaches the matching
    # launch() parameter.
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
else:
    demo.queue().launch()

Loading…
Cancel
Save