@@ -9,21 +9,33 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.trainer_utils import set_seed
 
 
-def demo_qwen_pretrain(args):
+def _load_model_tokenizer(args):
     tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True
+        args.checkpoint_path, trust_remote_code=True,
     )
     print("load tokenizer")
-    max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
 
-    n_gpus = torch.cuda.device_count()
-    max_memory = {i: max_memory for i in range(n_gpus)}
+    if args.cpu_only:
+        device_map = "cpu"
+        max_memory = None
+    else:
+        device_map = "auto"
+        max_memory_str = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
+        n_gpus = torch.cuda.device_count()
+        max_memory = {i: max_memory_str for i in range(n_gpus)}
+
     model = AutoModelForCausalLM.from_pretrained(
         args.checkpoint_path,
-        device_map="cuda:0",
+        device_map=device_map,
         max_memory=max_memory,
         trust_remote_code=True,
     ).eval()
+
+    return model, tokenizer
+
+
+def demo_qwen_pretrain(args):
+    model, tokenizer = _load_model_tokenizer(args)
     inputs = tokenizer(
         "蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
         return_tensors="pt",
@@ -34,20 +46,7 @@ def demo_qwen_pretrain(args):
 
 
 def demo_qwen_chat(args):
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.checkpoint_path, trust_remote_code=True
-    )
-    print("load tokenizer")
-    max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB"
-
-    n_gpus = torch.cuda.device_count()
-    max_memory = {i: max_memory for i in range(n_gpus)}
-    model = AutoModelForCausalLM.from_pretrained(
-        args.checkpoint_path,
-        device_map="cuda:0",
-        max_memory=max_memory,
-        trust_remote_code=True,
-    ).eval()
+    model, tokenizer = _load_model_tokenizer(args)
     queries = [
         "请问把大象关冰箱总共要几步?",
         "1+3=?",
@@ -65,16 +64,20 @@ def demo_qwen_chat(args):
         print("Response:", response, end="\n")
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser(description="Test HF checkpoint.")
     parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path")
     parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
-    parser.add_argument("--gpu", type=int, default=0, help="gpu id")
+    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
 
     args = parser.parse_args()
     set_seed(args.seed)
 
-    if 'chat' in args.checkpoint_path.lower():
+    if "chat" in args.checkpoint_path.lower():
         demo_qwen_chat(args)
     else:
         demo_qwen_pretrain(args)
+
+
+if __name__ == "__main__":
+    main()
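For reviewers who want to smoke-test the extracted helper by hand, here is a minimal sketch that drives `_load_model_tokenizer` directly. It is not part of the diff: `demo` is a hypothetical module name for the script above, and the checkpoint path is a placeholder; `cpu_only=True` exercises the new `device_map = "cpu"` branch, so no GPU is needed.

```python
# Minimal smoke-test sketch -- not part of the diff. "demo" is a
# hypothetical module name for the script above; the checkpoint path
# is a placeholder for any local or Hub Qwen checkpoint.
import argparse

from demo import _load_model_tokenizer

# Build the same namespace argparse would produce from
# "-c Qwen/Qwen-7B --cpu-only".
args = argparse.Namespace(checkpoint_path="Qwen/Qwen-7B", cpu_only=True)
model, tokenizer = _load_model_tokenizer(args)

# Greedy completion, mirroring the pretrain demo's prompt style.
inputs = tokenizer("The capital of Iceland is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Invoked as a script, the dispatch added in the last hunk picks the chat demo only when the checkpoint path contains "chat"; otherwise it runs the pretrain demo.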