import random import tqdm import os import re import sys import torch import numpy as np import jsonlines import argparse import jsonlines import datasets from datasets import load_from_disk,load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") INVALID_ANS = "[invalid]" def doc_to_text(doc): return fewshot_prompt + "\nQuestion: " + doc["question"] + "\nLet's think step by step\n" def decode(tokens_list, tokenizer, raw_text_len): sents = [] # print(len(tokens_list)) for tokens in tokens_list: tokens = tokens.cpu().numpy().tolist() sent = tokenizer.tokenizer.decode( tokens[raw_text_len:]) sent = sent.split('<|endoftext|>')[0] sent = sent.split('\n\n\n')[0] sent = sent.split("\n\n")[0] sent = sent.split("Question:")[0] sents.append(sent) return sents def generate_sample(model, tokenizer, input_txt): input_ids = tokenizer.tokenizer.encode(input_txt) raw_text_len = len(input_ids) context_enc = torch.tensor( [input_ids]).to(model.device) print(f"Input text: {input_txt}\n") outputs = model.generate(context_enc) output_text = decode(outputs,tokenizer,raw_text_len)[0] print(f"\nOutput text: {output_text}\n") return output_text def extract_answer_hf(completion): match = ANS_RE.search(completion) if match: match_str = match.group(1).strip() match_str = match_str.replace(",", "") return eval(match_str) else: return INVALID_ANS def extract_answer(completion): try: last_number = re.findall(r'\d+', completion)[-1] return eval(last_number) except: return INVALID_ANS def is_correct( completion, answer): gold = extract_answer_hf(answer) assert gold != INVALID_ANS, "No ground truth answer found in the document." return extract_answer(completion) == gold if __name__ == '__main__': parser = argparse.ArgumentParser(description='Test HF checkpoint.') parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path", default="Qwen/Qwen-7B") parser.add_argument("-f","--sample-input-file", type=str, default=None) parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl") args = parser.parse_args() fewshot_prompt = open("gsm8k_prompt.txt").read() if args.sample_input_file is not None: dataset = load_from_disk(args.sample_input_file) else: config = datasets.DownloadConfig(resume_download=True, max_retries=100) dataset = load_dataset("gsm8k", 'main', download_config=config) test = dataset["test"] print('Loading tokenizer ...') tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True) print('Loading model ...') model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval() model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True) model.generation_config.do_sample = False f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8')) tot_length = test.num_rows acc_res = [] for doc in test: context = doc_to_text(doc) completion = generate_sample(model, tokenizer, context) answer= doc["answer"] acc = is_correct(completion, answer) doc["completion"]=completion doc["acc"]=acc f_output.write(doc) acc_res.append(acc) f_output.close() print("Acc: ",np.mean(acc_res))