diff --git a/demo.py b/demo.py
index 38272df..1a32d00 100644
--- a/demo.py
+++ b/demo.py
@@ -9,21 +9,33 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.trainer_utils import set_seed
 
 
-def demo_qwen_pretrain(args):
+def _load_model_tokenizer(args):
     tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True
+        args.checkpoint_path, trust_remote_code=True,
     )
     print("load tokenizer")
 
-    max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
-    n_gpus = torch.cuda.device_count()
-    max_memory = {i: max_memory for i in range(n_gpus)}
+    if args.cpu_only:
+        device_map = "cpu"
+        max_memory = None
+    else:
+        device_map = "auto"
+        max_memory_str = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
+        n_gpus = torch.cuda.device_count()
+        max_memory = {i: max_memory_str for i in range(n_gpus)}
+
     model = AutoModelForCausalLM.from_pretrained(
         args.checkpoint_path,
-        device_map="cuda:0",
+        device_map=device_map,
         max_memory=max_memory,
         trust_remote_code=True,
     ).eval()
+
+    return model, tokenizer
+
+
+def demo_qwen_pretrain(args):
+    model, tokenizer = _load_model_tokenizer(args)
     inputs = tokenizer(
         "蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
         return_tensors="pt",
@@ -34,20 +46,7 @@ def demo_qwen_pretrain(args):
 
 
 def demo_qwen_chat(args):
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True
-    )
-    print("load tokenizer")
-    max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
-
-    n_gpus = torch.cuda.device_count()
-    max_memory = {i: max_memory for i in range(n_gpus)}
-    model = AutoModelForCausalLM.from_pretrained(
-        args.checkpoint_path,
-        device_map="cuda:0",
-        max_memory=max_memory,
-        trust_remote_code=True,
-    ).eval()
+    model, tokenizer = _load_model_tokenizer(args)
     queries = [
         "请问把大象关冰箱总共要几步?",
         "1+3=?",
@@ -65,16 +64,20 @@
     print("Response:", response, end="\n")
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser(description="Test HF checkpoint.")
     parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path")
     parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
-    parser.add_argument("--gpu", type=int, default=0, help="gpu id")
+    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
 
     args = parser.parse_args()
     set_seed(args.seed)
 
-    if 'chat' in args.checkpoint_path.lower():
+    if "chat" in args.checkpoint_path.lower():
         demo_qwen_chat(args)
     else:
         demo_qwen_pretrain(args)
+
+
+if __name__ == "__main__":
+    main()