You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
7.2 KiB
Python

# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A simple command-line interactive chat demo."""
import argparse
import os
import platform
import shutil
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed
# Checkpoint used when -c/--checkpoint-path is not given on the command line.
DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'

# Banner printed at startup and again after the screen is cleared.
_WELCOME_MSG = '''\
Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
'''

# Text printed by the :help / :h command. Every command name listed here must
# match a name accepted by the command dispatch in main(); the clear-history
# entry is spelled ':clear-history' because that is the long form main() accepts.
_HELP_MSG = '''\
Commands:
:help / :h Show this help message 显示帮助信息
:exit / :quit / :q Exit the demo 退出Demo
:clear / :cl Clear screen 清屏
:clear-history / :clh Clear history 清除对话历史
:history / :his Show history 显示对话历史
:seed Show current random seed 显示当前随机种子
:seed <N> Set random seed to <N> 设置随机种子
:conf Show current generation config 显示生成配置
:conf <key>=<value> Change generation config 修改生成配置
:reset-conf Reset generation config 重置生成配置
'''
def _load_model_tokenizer(args):
    """Load the tokenizer, the model (in eval mode) and its generation config.

    Args:
        args: Parsed CLI namespace; ``checkpoint_path`` and ``cpu_only`` are read.

    Returns:
        A ``(model, tokenizer, config)`` tuple, where ``config`` is the
        ``GenerationConfig`` loaded from the same checkpoint.
    """
    ckpt = args.checkpoint_path
    # Options shared by every from_pretrained / from_quantized call below.
    hub_kwargs = dict(trust_remote_code=True, resume_download=True)

    tokenizer = AutoTokenizer.from_pretrained(ckpt, **hub_kwargs)

    device_map = "cpu" if args.cpu_only else "auto"

    # A quantize_config.json next to the checkpoint selects the auto_gptq
    # loader; the import is local so auto_gptq is only needed in that case.
    if os.path.exists(os.path.join(ckpt, 'quantize_config.json')):
        from auto_gptq import AutoGPTQForCausalLM
        loaded = AutoGPTQForCausalLM.from_quantized(
            ckpt,
            device_map=device_map,
            use_safetensors=True,
            **hub_kwargs,
        )
    else:
        loaded = AutoModelForCausalLM.from_pretrained(
            ckpt,
            device_map=device_map,
            **hub_kwargs,
        )
    model = loaded.eval()

    config = GenerationConfig.from_pretrained(ckpt, **hub_kwargs)
    return model, tokenizer, config
def _clear_screen():
    """Clear the terminal using the shell command for the current platform."""
    # Windows uses "cls"; every other platform is given "clear".
    shell_cmd = "cls" if platform.system() == "Windows" else "clear"
    os.system(shell_cmd)
def _print_history(history):
    """Print every (query, response) pair in *history* between two rulers.

    The header line shows the number of pairs, centred and padded with '='
    to the terminal width; the footer is a full-width line of '='.
    """
    width = shutil.get_terminal_size().columns
    header = f'History ({len(history)})'
    print(header.center(width, '='))
    for turn_no, turn in enumerate(history):
        query, response = turn
        print(f'User[{turn_no}]: {query}')
        print(f'QWen[{turn_no}]: {response}')
    print(width * '=')
def _get_input() -> str:
    """Prompt with 'User> ' until a non-empty line is entered.

    Returns:
        The entered line with surrounding whitespace stripped.

    Raises:
        SystemExit: With exit code 1 when the prompt is interrupted with
            Ctrl-C (KeyboardInterrupt) or the input stream ends (EOFError,
            e.g. Ctrl-D or exhausted piped stdin).
    """
    while True:
        try:
            message = input('User> ').strip()
        except UnicodeDecodeError:
            print('[ERROR] Encoding error in input')
            continue
        except (KeyboardInterrupt, EOFError):
            # Raise SystemExit directly instead of calling exit(): that helper
            # is installed by the `site` module and is missing under
            # `python -S` or in some embedded/frozen interpreters.
            raise SystemExit(1) from None
        if message:
            return message
        print('[ERROR] Query is empty')
def main():
    """Parse CLI arguments, load the model and run the interactive chat loop.

    Input lines starting with ':' are matched against the demo commands
    listed in ``_HELP_MSG``; an unrecognized ':' line, like any other line,
    is sent to the model as a query. Each completed exchange is appended to
    an in-memory history list that is passed to later ``chat_stream`` calls.
    The loop ends on ``:exit`` / ``:quit`` / ``:q``.
    """
    parser = argparse.ArgumentParser(
        description='QWen-7B-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args()

    history = []

    model, tokenizer, config = _load_model_tokenizer(args)
    # Install the loaded config as the model's live generation config so that
    # the object edited by :conf / :reset-conf is the same one handed to
    # chat_stream below. Previously generation always received the separate
    # `config` object, so :conf edits never affected the output.
    model.generation_config = config
    orig_gen_config = deepcopy(model.generation_config)

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(':'):
            command_words = query[1:].strip().split()
            command = command_words[0] if command_words else ''

            if command in ('exit', 'quit', 'q'):
                break
            elif command in ('clear', 'cl'):
                _clear_screen()
                print(_WELCOME_MSG)
                continue
            elif command in ('clear-history', 'clear-his', 'clh'):
                # 'clear-his' is the spelling shown in the help text, so it is
                # accepted alongside the long form.
                print(f'[INFO] All {len(history)} history cleared')
                history.clear()
                continue
            elif command in ('help', 'h'):
                print(_HELP_MSG)
                continue
            elif command in ('history', 'his'):
                _print_history(history)
                continue
            elif command == 'seed':
                if len(command_words) == 1:
                    print(f'[INFO] Current random seed: {seed}')
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
                    else:
                        print(f'[INFO] Random seed changed to {new_seed}')
                        seed = new_seed
                continue
            elif command == 'conf':
                if len(command_words) == 1:
                    print(model.generation_config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        conf_key, sep, conf_value_str = key_value_pairs_str.partition('=')
                        # Reject both a missing '=' and an empty key, which
                        # would otherwise set an attribute named ''.
                        if not sep or not conf_key:
                            print('[WARNING] format: <key>=<value>')
                            continue
                        try:
                            # NOTE(review): eval() executes whatever Python
                            # expression is typed at the prompt. Acceptable
                            # only while the prompt is driven by the local
                            # operator; never feed it untrusted input.
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            print(e)
                            continue
                        print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
                        setattr(model.generation_config, conf_key, conf_value)
                continue
            elif command == 'reset-conf':
                print('[INFO] Reset generation config')
                model.generation_config = deepcopy(orig_gen_config)
                print(model.generation_config)
                continue
            # Any other ':' line falls through and is treated as a normal query.

        # Run chat.
        set_seed(seed)
        # Reset per turn so a stream that yields nothing cannot store the
        # previous turn's response in the history.
        response = ''
        try:
            for response in model.chat_stream(tokenizer, query, history=history,
                                              generation_config=model.generation_config):
                # Redraw the whole exchange for every value the stream yields.
                _clear_screen()
                print(f"\nUser: {query}")
                print(f"\nQwen-7B: {response}")
        except KeyboardInterrupt:
            # An interrupted turn is dropped: nothing is added to the history.
            print('[WARNING] Generation interrupted')
            continue

        history.append((query, response))
# Start the demo only when this file is executed as a script, not on import.
if "__main__" == __name__:
    main()