diff --git a/demo.py b/demo.py deleted file mode 100644 index 1a32d00..0000000 --- a/demo.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import argparse -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.trainer_utils import set_seed - - -def _load_model_tokenizer(args): - tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True, - ) - print("load tokenizer") - - if args.cpu_only: - device_map = "cpu" - max_memory = None - else: - device_map = "auto" - max_memory_str = f"{int(torch.cuda.mem_get_info()[0] / 1024 ** 3) - 2}GB" - n_gpus = torch.cuda.device_count() - max_memory = {i: max_memory_str for i in range(n_gpus)} - - model = AutoModelForCausalLM.from_pretrained( - args.checkpoint_path, - device_map=device_map, - max_memory=max_memory, - trust_remote_code=True, - ).eval() - - return model, tokenizer - - -def demo_qwen_pretrain(args): - model, tokenizer = _load_model_tokenizer(args) - inputs = tokenizer( - "蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是", - return_tensors="pt", - ) - inputs = inputs.to(model.device) - pred = model.generate(inputs=inputs["input_ids"]) - print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)) - - -def demo_qwen_chat(args): - model, tokenizer = _load_model_tokenizer(args) - queries = [ - "请问把大象关冰箱总共要几步?", - "1+3=?", - "请将下面这句话翻译为英文:在哪里跌倒就在哪里趴着", - ] - history = None - for turn_idx, query in enumerate(queries, start=1): - response, history = model.chat( - tokenizer, - query, - history=history, - ) - print(f"===== Turn {turn_idx} ====") - print("Query:", query, end="\n") - print("Response:", response, end="\n") - - -def main(): - parser = argparse.ArgumentParser(description="Test HF checkpoint.") - parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path") - parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed") - parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") - - args = parser.parse_args() - set_seed(args.seed) - - if "chat" in args.checkpoint_path.lower(): - demo_qwen_chat(args) - else: - demo_qwen_pretrain(args) - - -if __name__ == "__main__": - main()