From ee69a08a7327c5b07dc7caccc5a52af3fa0c0553 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=8B=8F=E9=98=B3?=
Date: Sat, 5 Aug 2023 20:47:58 +0800
Subject: [PATCH] Improve streaming chat demo.

---
 cli_demo.py | 203 +++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 162 insertions(+), 41 deletions(-)

diff --git a/cli_demo.py b/cli_demo.py
index 2ef0a11..3e97735 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -1,68 +1,189 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A simple command-line interactive chat demo."""
+
+import argparse
 import os
 import platform
-import signal
+import shutil
+from copy import deepcopy
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
+from transformers.trainer_utils import set_seed
 
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
-# We recommend checking the support of BF16 first. Run the command below:
-# import torch
-# torch.cuda.is_bf16_supported()
-# use bf16
-# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
-# use fp16
-# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
-# use cpu only
-# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
-# use fp32
-model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
-model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat",
-                                                           trust_remote_code=True)  # 可指定不同的生成长度、top_p等相关超参
+DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'
 
-stop_stream = False
+_WELCOME_MSG = '''\
+Welcome to the Qwen-7B-Chat model. Type text to start a chat, or :h to show command help.
+欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
+'''
+_HELP_MSG = '''\
+Commands:
+    :help / :h              Show this help message              显示帮助信息
+    :exit / :quit / :q      Exit the demo                       退出Demo
+    :clear / :cl            Clear screen                        清屏
+    :clear-history / :clh   Clear history                       清除对话历史
+    :history / :his         Show history                        显示对话历史
+    :seed                   Show current random seed            显示当前随机种子
+    :seed <N>               Set random seed to <N>              设置随机种子为 <N>
+    :conf                   Show current generation config      显示生成配置
+    :conf <key>=<value>     Change generation config            修改生成配置
+    :reset-conf             Reset generation config             重置生成配置
+'''
 
 
-def signal_handler(signal, frame):
-    global stop_stream
-    stop_stream = True
+def _load_model_tokenizer(args):
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.checkpoint_path, trust_remote_code=True,
+    )
+
+    if args.cpu_only:
+        device_map = "cpu"
+    else:
+        device_map = "auto"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.checkpoint_path,
+        device_map=device_map,
+        trust_remote_code=True,
+    ).eval()
+    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+    return model, tokenizer
 
 
-def clear_screen():
+def _clear_screen():
     if platform.system() == "Windows":
         os.system("cls")
     else:
         os.system("clear")
 
 
-def print_history(history):
-    for pair in history:
-        print(f"\nUser:{pair[0]}\nQwen-7B:{pair[1]}")
+def _print_history(history):
+    terminal_width = shutil.get_terminal_size()[0]
+    print(f'History ({len(history)})'.center(terminal_width, '='))
+    for index, (query, response) in enumerate(history):
+        print(f'User[{index}]: {query}')
+        print(f'Qwen[{index}]: {response}')
+    print('=' * terminal_width)
+
+
+def _get_input() -> str:
+    while True:
+        try:
+            message = input('User> ').strip()
+        except UnicodeDecodeError:
+            print('[ERROR] Encoding error in input')
+            continue
+        except KeyboardInterrupt:
+            exit(1)
+        if message:
+            return message
+        print('[ERROR] Query is empty')
 
 
 def main():
+    parser = argparse.ArgumentParser(
+        description='Qwen-7B-Chat command-line interactive chat demo.')
+    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
+                        help="Checkpoint name or path, defaults to %(default)r")
+    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
+    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
+    args = parser.parse_args()
+
     history, response = [], ''
-    global stop_stream
-    clear_screen()
-    print("欢迎使用 Qwen-7B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+
+    model, tokenizer = _load_model_tokenizer(args)
+    orig_gen_config = deepcopy(model.generation_config)
+
+    _clear_screen()
+    print(_WELCOME_MSG)
+
+    seed = args.seed
+
     while True:
-        query = input("\nUser:")
-        if query.strip() == "stop":
-            break
-        if query.strip() == "clear":
-            history = []
-            clear_screen()
-            print("欢迎使用 Qwen-7B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
-            continue
-        for response in model.chat(tokenizer, query, history=history, stream=True):
-            if stop_stream:
-                stop_stream = False
+        query = _get_input()
+
+        # Process commands.
+        if query.startswith(':'):
+            command_words = query[1:].strip().split()
+            if not command_words:
+                command = ''
+            else:
+                command = command_words[0]
+
+            if command in ['exit', 'quit', 'q']:
                 break
+            elif command in ['clear', 'cl']:
+                _clear_screen()
+                print(_WELCOME_MSG)
+                continue
+            elif command in ['clear-history', 'clh']:
+                print(f'[INFO] All {len(history)} history turns cleared')
+                history.clear()
+                continue
+            elif command in ['help', 'h']:
+                print(_HELP_MSG)
+                continue
+            elif command in ['history', 'his']:
+                _print_history(history)
+                continue
+            elif command in ['seed']:
+                if len(command_words) == 1:
+                    print(f'[INFO] Current random seed: {seed}')
+                    continue
+                else:
+                    new_seed_s = command_words[1]
+                    try:
+                        new_seed = int(new_seed_s)
+                    except ValueError:
+                        print(f'[WARNING] Failed to change random seed: {new_seed_s!r} is not a valid number')
+                    else:
+                        print(f'[INFO] Random seed changed to {new_seed}')
+                        seed = new_seed
+                    continue
+            elif command in ['conf']:
+                if len(command_words) == 1:
+                    print(model.generation_config)
+                else:
+                    for key_value_pairs_str in command_words[1:]:
+                        eq_idx = key_value_pairs_str.find('=')
+                        if eq_idx == -1:
+                            print('[WARNING] format: <key>=<value>')
+                            continue
+                        conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
+                        try:
+                            conf_value = eval(conf_value_str)
+                        except Exception as e:
+                            print(e)
+                            continue
+                        else:
+                            print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
+                            setattr(model.generation_config, conf_key, conf_value)
+                continue
+            elif command in ['reset-conf']:
+                print('[INFO] Reset generation config')
+                model.generation_config = deepcopy(orig_gen_config)
+                print(model.generation_config)
+                continue
             else:
-                clear_screen()
-                print_history(history)
+                # As normal query.
+                pass
+
+        # Run chat.
+        set_seed(seed)
+        try:
+            for response in model.chat(tokenizer, query, history=history, stream=True):
+                _clear_screen()
                 print(f"\nUser: {query}")
-                print("\nQwen-7B:", end="")
-                print(response)
+                print(f"\nQwen-7B: {response}")
+        except KeyboardInterrupt:
+            print('[WARNING] Generation interrupted')
+            continue
+
         history.append((query, response))
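
Note on the streaming redraw: the "Run chat." loop clears the screen and reprints on every iteration because Qwen's remote-code chat(..., stream=True) appears to yield the full response accumulated so far rather than a delta (the old code's reprint-everything loop behaved the same way). A minimal standalone sketch of that pattern, under this assumption; the checkpoint path matches the demo's default and the prompt is illustrative:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig

    ckpt = 'Qwen/Qwen-7B-Chat'
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        ckpt, device_map="auto", trust_remote_code=True,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(ckpt, trust_remote_code=True)

    history = []
    for partial in model.chat(tokenizer, 'Hello!', history=history, stream=True):
        # Each yielded `partial` supersedes the previous one, so a UI should
        # replace its earlier output rather than append; the demo does this
        # by clearing the whole screen before reprinting.
        print(partial)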