You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
2.5 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import set_seed
def _load_model_tokenizer(args):
    """Load the tokenizer and causal LM found at ``args.checkpoint_path``.

    Args:
        args: Parsed CLI namespace. ``checkpoint_path`` selects the checkpoint
            and ``cpu_only`` forces the whole model onto the CPU.

    Returns:
        A ``(model, tokenizer)`` tuple; ``.eval()`` has been called on the model.

    Raises:
        RuntimeError: If GPU placement is requested (``cpu_only`` is false)
            but no CUDA device is visible.
    """
    # NOTE(review): trust_remote_code=True lets the checkpoint supply its own
    # Python code, so only point this at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
    )
    print("load tokenizer")

    if args.cpu_only:
        device_map = "cpu"
        max_memory = None
    else:
        n_gpus = torch.cuda.device_count()
        if n_gpus == 0:
            raise RuntimeError(
                "No CUDA device is available; pass --cpu-only to run on the CPU."
            )
        device_map = "auto"
        # Query free memory per device: mem_get_info() with no argument only
        # reports the *current* device, which is wrong for the other GPUs when
        # they are unevenly occupied. 2 GiB of headroom is kept back on each.
        max_memory = {
            gpu_idx: f"{int(torch.cuda.mem_get_info(gpu_idx)[0] / 1024 ** 3) - 2}GB"
            for gpu_idx in range(n_gpus)
        }

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        max_memory=max_memory,
        trust_remote_code=True,
    ).eval()
    return model, tokenizer
def demo_qwen_pretrain(args):
    """Complete a fixed few-shot prompt with the loaded checkpoint and print it.

    Args:
        args: Parsed CLI namespace, forwarded to ``_load_model_tokenizer``.
    """
    model, tokenizer = _load_model_tokenizer(args)

    # Two worked "capital of X" examples followed by an unfinished third line
    # for the model to continue.
    prompt = "蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    generated = model.generate(inputs=encoded["input_ids"])
    # Only one prompt was submitted, so decode the first (only) sequence.
    first_sequence = generated.cpu()[0]
    print(tokenizer.decode(first_sequence, skip_special_tokens=True))
def demo_qwen_chat(args):
    """Run a fixed multi-turn conversation through ``model.chat`` and print each turn.

    Args:
        args: Parsed CLI namespace, forwarded to ``_load_model_tokenizer``.
    """
    model, tokenizer = _load_model_tokenizer(args)

    queries = (
        "请问把大象关冰箱总共要几步?",
        "1+3=?",
        "请将下面这句话翻译为英文:在哪里跌倒就在哪里趴着",
    )

    # The history returned by one turn is passed into the next, so later
    # queries are answered with the earlier exchanges as context.
    history = None
    turn_idx = 0
    for query in queries:
        turn_idx += 1
        response, history = model.chat(tokenizer, query, history=history)
        print(f"===== Turn {turn_idx} ====")
        print("Query:", query)
        print("Response:", response)
def main(argv=None):
    """Parse the command line, apply the seed, and run the matching demo.

    The chat demo runs when the checkpoint path contains "chat"
    (case-insensitive); otherwise the plain completion demo runs.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line.
    """
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    # required=True: without a path, args.checkpoint_path would be None and the
    # .lower() call below would fail with an unhelpful AttributeError.
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        required=True,
        help="Checkpoint path",
    )
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args(argv)

    set_seed(args.seed)

    if "chat" in args.checkpoint_path.lower():
        demo_qwen_chat(args)
    else:
        demo_qwen_pretrain(args)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()