From 47d984e90eae3789db8efc8034682c0a82def761 Mon Sep 17 00:00:00 2001
From: Iurnem
Date: Fri, 4 Aug 2023 22:21:22 +0800
Subject: [PATCH] Add files via upload

---
 cli_demo.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)
 create mode 100644 cli_demo.py

diff --git a/cli_demo.py b/cli_demo.py
new file mode 100644
index 0000000..2ef0a11
--- /dev/null
+++ b/cli_demo.py
@@ -0,0 +1,75 @@
+import os
+import platform
+import signal
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+# We recommend checking BF16 support first. Run the commands below:
+# import torch
+# torch.cuda.is_bf16_supported()
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
+# use fp32
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat",
+                                                           trust_remote_code=True)  # Different generation hyperparameters, e.g. generation length and top_p, can be specified here
+
+stop_stream = False
+
+
+def signal_handler(sig, frame):
+    # Flag the streaming loop to stop at the next chunk instead of killing the process.
+    global stop_stream
+    stop_stream = True
+
+
+def clear_screen():
+    if platform.system() == "Windows":
+        os.system("cls")
+    else:
+        os.system("clear")
+
+
+def print_history(history):
+    for query, response in history:
+        print(f"\nUser: {query}\nQwen-7B: {response}")
+
+
+def main():
+    history, response = [], ''
+    global stop_stream
+    # Route Ctrl+C to signal_handler so it cancels the current generation
+    # rather than terminating the whole program.
+    signal.signal(signal.SIGINT, signal_handler)
+    clear_screen()
+    print("Welcome to the Qwen-7B model. Type your message to chat, clear to clear the conversation history, stop to exit.")
+    while True:
+        query = input("\nUser: ")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            clear_screen()
+            print("Welcome to the Qwen-7B model. Type your message to chat, clear to clear the conversation history, stop to exit.")
+            continue
+        # Redraw the whole conversation as the streamed response is updated.
+        for response in model.chat(tokenizer, query, history=history, stream=True):
+            if stop_stream:
+                stop_stream = False
+                break
+            clear_screen()
+            print_history(history)
+            print(f"\nUser: {query}")
+            print("\nQwen-7B:", end="")
+            print(response)
+
+        history.append((query, response))
+
+
+if __name__ == "__main__":
+    main()
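
Note (outside the patch): the comment block at the top of cli_demo.py recommends checking BF16 support before choosing a precision. A minimal sketch of automating that choice, assuming the same Qwen/Qwen-7B-Chat checkpoint and the bf16/fp16 loader flags shown in the comments above, could look like this:

    import torch
    from transformers import AutoModelForCausalLM

    # Pick the loader variant based on the available hardware; the three
    # branches mirror the cpu/bf16/fp16 options commented out in cli_demo.py.
    if not torch.cuda.is_available():
        # No GPU: fall back to the CPU (fp32) path.
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
    elif torch.cuda.is_bf16_supported():
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()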