You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
3.7 KiB
Python

1 year ago
import re
import torch
import argparse
import jsonlines
import numpy as np
1 year ago
import datasets
from datasets import load_from_disk, load_dataset
1 year ago
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
# Matches the "#### <number>" marker and captures the number text, which may
# carry a leading minus sign, decimal points and comma separators.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned by extract_answer_hf() and extract_answer() when no
# numeric answer can be recovered from the text.
INVALID_ANS = "[invalid]"
1 year ago
def doc_to_text(doc):
    """Build the model prompt for one dataset example.

    The global ``fewshot_prompt`` (assigned by the script entry point) is
    followed by the example's question and the "Let's think step by step"
    cue.

    Args:
        doc: Mapping that provides the question text under ``"question"``.

    Returns:
        The complete prompt string.
    """
    pieces = (
        fewshot_prompt,
        "\nQuestion: ",
        doc["question"],
        "\nLet's think step by step\n",
    )
    return "".join(pieces)
1 year ago
def decode(tokens_list, tokenizer, raw_text_len):
    """Convert generated token sequences into trimmed completion strings.

    Args:
        tokens_list: Iterable of token sequences; each element must support
            ``.cpu().numpy().tolist()``.
        tokenizer: Wrapper whose inner ``tokenizer`` attribute offers
            ``decode``.
        raw_text_len: Number of leading tokens to skip in every sequence
            (the caller passes the prompt length).

    Returns:
        A list with one string per sequence, cut at the first stop marker.
    """
    # Applied in this order; each cut keeps only the text before the marker.
    stop_markers = ("<|endoftext|>", "\n\n\n", "\n\n", "Question:")

    def _trim(text):
        for marker in stop_markers:
            text = text.split(marker)[0]
        return text

    return [
        _trim(tokenizer.tokenizer.decode(seq.cpu().numpy().tolist()[raw_text_len:]))
        for seq in tokens_list
    ]
1 year ago
def generate_sample(model, tokenizer, input_txt):
    """Generate a completion for one prompt, echoing prompt and result.

    Args:
        model: Object exposing ``device`` and ``generate``.
        tokenizer: Wrapper whose inner ``tokenizer`` attribute offers
            ``encode`` (and ``decode``, used by ``decode()``).
        input_txt: Prompt text to feed to the model.

    Returns:
        The trimmed completion text for the single prompt.
    """
    prompt_ids = tokenizer.tokenizer.encode(input_txt)
    prompt_len = len(prompt_ids)
    # Batch of one, placed on the same device as the model.
    batch = torch.tensor([prompt_ids]).to(model.device)

    print(f"Input text: {input_txt}\n")
    generated = model.generate(batch)

    # decode() drops the first prompt_len tokens of each sequence.
    completions = decode(generated, tokenizer, prompt_len)
    output_text = completions[0]
    print(f"\nOutput text: {output_text}\n")
    return output_text
def extract_answer_hf(completion):
    """Parse the number that follows the ``####`` marker in an answer text.

    Args:
        completion: Text expected to contain ``#### <number>``.

    Returns:
        The number as an ``int`` when the captured text is an integer,
        otherwise as a ``float``; comma separators are removed first.
        ``INVALID_ANS`` is returned when the marker is missing or the
        captured text is not a number (for example ``"1.2.3"``).
    """
    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS

    number_text = match.group(1).replace(",", "")
    # Explicit numeric parsing instead of eval(): nothing is executed, and a
    # malformed capture yields INVALID_ANS rather than a SyntaxError.
    for parse in (int, float):
        try:
            return parse(number_text)
        except ValueError:
            continue
    return INVALID_ANS
1 year ago
def extract_answer(completion):
    """Return the last run of digits in a completion as an integer.

    Args:
        completion: Model output text.

    Returns:
        The integer value of the final ``\\d+`` match, or ``INVALID_ANS``
        when there is no such match or the token cannot be parsed.
    """
    import ast  # function-scope import keeps this block self-contained

    try:
        last_number = re.findall(r"\d+", completion)[-1]
        # The token is digits only, so literal_eval gives exactly what eval()
        # gave (including rejecting leading zeros such as "05") without
        # executing model-generated text as code.
        return ast.literal_eval(last_number)
    except Exception:
        # Best-effort extraction: any ordinary failure counts as "no answer".
        # Unlike a bare ``except:``, KeyboardInterrupt and SystemExit still
        # propagate.
        return INVALID_ANS
def is_correct(completion, answer):
    """Check a model completion against the reference answer text.

    Args:
        completion: Model output text; its answer is taken by
            ``extract_answer``.
        answer: Reference answer text; its number is taken by
            ``extract_answer_hf``.

    Returns:
        ``True`` when the extracted completion answer equals the reference.

    Raises:
        AssertionError: If no reference answer can be extracted from
            ``answer``.
    """
    gold = extract_answer_hf(answer)
    if gold == INVALID_ANS:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("No ground truth answer found in the document.")
    return extract_answer(completion) == gold
if __name__ == "__main__":
    # Script entry point: run a Hugging Face causal-LM checkpoint over the
    # GSM8K test split, write one JSON line per example (with completion and
    # correctness flag), and print the mean accuracy.
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        help="Checkpoint path",
        default="Qwen/Qwen-7B",
    )
    parser.add_argument("-f", "--sample-input-file", type=str, default=None)
    parser.add_argument(
        "-o", "--sample-output-file", type=str, default="gsm8k_res.jsonl"
    )
    args = parser.parse_args()

    # doc_to_text() reads this name as a global, so it must be assigned at
    # module scope. The context manager closes the prompt file promptly.
    with open("gsm8k_prompt.txt", encoding="utf-8") as prompt_file:
        fewshot_prompt = prompt_file.read()

    if args.sample_input_file is not None:
        dataset = load_from_disk(args.sample_input_file)
    else:
        config = datasets.DownloadConfig(resume_download=True, max_retries=100)
        dataset = load_dataset("gsm8k", "main", download_config=config)

    test = dataset["test"]

    print("Loading tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    print("Loading model ...")
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    # Turn off sampling during generation.
    model.generation_config.do_sample = False

    acc_res = []
    # jsonlines.Writer.close() does not close a file object handed to it, so
    # the underlying file is managed here and closed even if a sample fails.
    with open(args.sample_output_file, "w", encoding="utf-8") as out_file:
        with jsonlines.Writer(out_file) as f_output:
            for doc in test:
                context = doc_to_text(doc)
                completion = generate_sample(model, tokenizer, context)
                answer = doc["answer"]
                acc = is_correct(completion, answer)
                doc["completion"] = completion
                doc["acc"] = acc
                f_output.write(doc)
                acc_res.append(acc)

    print("Acc: ", np.mean(acc_res))