import os
from typing import List
import pandas as pd
import numpy as np
import argparse
import torch
from tqdm import tqdm
from transformers.trainer_utils import set_seed
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

"""
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
mkdir data/mmlu
mv data.tar data/mmlu
cd data/mmlu; tar xf data.tar
cd ../../
python eval/evaluate_mmlu.py -d data/mmlu/data/
"""


def load_models_tokenizer(args):
    """Load the causal LM and its tokenizer from the checkpoint path."""
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        pad_token='<|extra_0|>',
        eos_token='<|endoftext|>',
        padding_side='left',
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        pad_token_id=tokenizer.pad_token_id,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        pad_token_id=tokenizer.pad_token_id,
        trust_remote_code=True,
    )
    return model, tokenizer


def format_example(line, include_answer=True):
    """Format a single MMLU row as a question with lettered choices."""
    example = "Question: " + line["question"]
    for choice in choices:
        example += f'\n{choice}. {line[f"{choice}"]}'
    if include_answer:
        example += "\nAnswer: " + line["answer"] + "\n\n"
    else:
        example += "\nAnswer:"
    return example
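

# Illustrative only: with include_answer=True, format_example() produces a block
# shaped like the following (the question and choices here are made up):
#
#   Question: What is 2 + 2?
#   A. 3
#   B. 4
#   C. 5
#   D. 6
#   Answer: B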


def generate_few_shot_prompt(k, subject, dev_df):
    """Build the k-shot prompt header from the dev split of a subject."""

    def format_subject(subject):
        # "abstract_algebra" -> "abstract algebra"
        return " ".join(subject.split("_")).strip()

    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = dev_df.shape[0]
    for i in range(k):
        prompt += format_example(
            dev_df.iloc[i, :],
            include_answer=True,
        )
    return prompt


def get_logits(tokenizer, model, inputs: List[str]):
    """Run a forward pass and return next-token probabilities for each prompt.

    Note: the returned tensor is already softmax-normalised over the vocabulary
    at the last position, despite the `logits` naming at the call site.
    """
    input_ids = tokenizer(inputs, padding='longest')["input_ids"]
    input_ids = torch.tensor(input_ids, device=model.device)
    # Truncate from the left so the most recent context (the question) is kept.
    if input_ids.shape[1] > args.max_seq_len:
        input_ids = input_ids[:, input_ids.shape[1] - args.max_seq_len + 1:]
    tokens = {"input_ids": input_ids}
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    outputs = model(input_ids, attention_mask=attention_mask)["logits"]
    logits = outputs[:, -1, :]
    log_probs = torch.nn.functional.softmax(logits, dim=-1)
    return log_probs, {"tokens": tokens}


@torch.no_grad()
def eval_subject(
    model,
    tokenizer,
    subject_name,
    test_df,
    k=5,
    dev_df=None,
    few_shot=False,
    save_result_dir=None,
    batch_size=1,
    **kwargs,
):
    result = []
    score = []

    few_shot_prompt = (
        generate_few_shot_prompt(k, subject_name, dev_df) if few_shot else ""
    )
    all_probs = {"prob_A": [], "prob_B": [], "prob_C": [], "prob_D": []}
    if args.debug:
        print(f"few_shot_prompt: {few_shot_prompt}")

    # Token ids of the four answer letters (" A", " B", " C", " D"), shape (1, 4).
    choices_ids = torch.tensor(
        tokenizer(" A")["input_ids"] + tokenizer(" B")["input_ids"] +
        tokenizer(" C")["input_ids"] + tokenizer(" D")["input_ids"]
    ).unsqueeze(0).to(model.device)

    idx_list = list(range(0, len(test_df), batch_size))
    for i in tqdm(idx_list):
        full_prompt_list = []
        answer_list = []
        for row in test_df.iloc[i:i + batch_size].to_dict(orient='records'):
            question = format_example(row, include_answer=False)
            full_prompt = few_shot_prompt + question
            full_prompt_list.append(full_prompt)
            if 'answer' in row:
                answer_list.append(row['answer'])

        logits, input_info = get_logits(tokenizer, model, full_prompt_list)
        # Keep only the probabilities of the four answer tokens and renormalise.
        softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1)
        if softval.dtype in {torch.bfloat16, torch.float16}:
            softval = softval.to(dtype=torch.float32)
        probs = softval.detach().cpu().numpy()

        for row_idx in range(len(probs)):
            for j, choice in enumerate(choices):
                all_probs[f"prob_{choice}"].append(probs[row_idx][j])
            pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[row_idx])]

            if answer_list != []:
                correct = 1 if pred == answer_list[row_idx] else 0
                score.append(correct)
                if args.debug:
                    print(f'{question} pred: {pred} ref: {answer_list[row_idx]}')
            result.append(pred)

    if save_result_dir:
        test_df["model_output"] = result
        for i, choice in enumerate(choices):
            test_df[f"prob_{choice}"] = all_probs[f"prob_{choice}"]
        if score:
            test_df["correctness"] = score
        os.makedirs(save_result_dir, exist_ok=True)
        test_df.to_csv(
            os.path.join(save_result_dir, f"{subject_name}_result.csv"),
            encoding="utf-8",
            index=False,
        )

    return score


def cal_mmlu(res):
    """Aggregate per-subject scores into per-category and overall accuracy."""
    acc_sum_dict = dict()
    acc_norm_sum_dict = dict()
    cnt_dict = dict()
    acc_sum = 0.0
    cnt = 0
    hard_cnt = 0
    hard_acc_sum = 0.0

    for class_ in TASK_NAME_MAPPING.keys():
        acc_sum_dict[class_] = 0.0
        acc_norm_sum_dict[class_] = 0.0
        cnt_dict[class_] = 0.0

        for tt in TASK_NAME_MAPPING[class_]:
            acc_sum += sum(res[tt])
            cnt += len(res[tt])
            acc_sum_dict[class_] += sum(res[tt])
            cnt_dict[class_] += len(res[tt])

    print("\n\n\n", "total cnt:", cnt, "\n")
    for k in TASK_NAME_MAPPING.keys():
        if k in cnt_dict:
            print("%s ACC: %.2f " % (k, acc_sum_dict[k] / cnt_dict[k] * 100))
    print("AVERAGE ACC:%.2f " % (acc_sum / cnt * 100))


def main(args):
    model, tokenizer = load_models_tokenizer(args)

    dev_result = {}
    for subject_name in tqdm(SUBJECTS):
        # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
        dev_file_path = os.path.join(
            args.eval_data_path, "dev", f"{subject_name}_dev.csv"
        )
        test_file_path = os.path.join(
            args.eval_data_path, "test", f"{subject_name}_test.csv"
        )
        # val_df = pd.read_csv(val_file_path, names=['question','A','B','C','D','answer'])
        dev_df = pd.read_csv(
            dev_file_path, names=["question", "A", "B", "C", "D", "answer"]
        )
        test_df = pd.read_csv(
            test_file_path, names=["question", "A", "B", "C", "D", "answer"]
        )

        score = eval_subject(
            model,
            tokenizer,
            subject_name,
            test_df,
            dev_df=dev_df,
            k=5,
            few_shot=True,
            save_result_dir="outs/mmlu_eval_result",
            batch_size=args.batch_size,
        )
        dev_result[subject_name] = score
    cal_mmlu(dev_result)


TASK_NAME_MAPPING = {
    "stem": [
        "abstract_algebra",
        "anatomy",
        "astronomy",
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_physics",
        "computer_security",
        "conceptual_physics",
        "electrical_engineering",
        "elementary_mathematics",
        "high_school_biology",
        "high_school_chemistry",
        "high_school_computer_science",
        "high_school_mathematics",
        "high_school_physics",
        "high_school_statistics",
        "machine_learning",
    ],
    "Humanities": [
        "formal_logic",
        "high_school_european_history",
        "high_school_us_history",
        "high_school_world_history",
        "international_law",
        "jurisprudence",
        "logical_fallacies",
        "moral_disputes",
        "moral_scenarios",
        "philosophy",
        "prehistory",
        "professional_law",
        "world_religions",
    ],
    "other": [
        "business_ethics",
        "college_medicine",
        "human_aging",
        "management",
        "marketing",
        "medical_genetics",
        "miscellaneous",
        "nutrition",
        "professional_accounting",
        "professional_medicine",
        "virology",
        "global_facts",
        "clinical_knowledge",
    ],
    "social": [
        "econometrics",
        "high_school_geography",
        "high_school_government_and_politics",
        "high_school_macroeconomics",
        "high_school_microeconomics",
        "high_school_psychology",
        "human_sexuality",
        "professional_psychology",
        "public_relations",
        "security_studies",
        "sociology",
        "us_foreign_policy",
    ],
}
SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
choices = ["A", "B", "C", "D"]
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test HF checkpoint.")
parser.add_argument(
"-c",
"--checkpoint-path",
type=str,
help="Checkpoint path",
default="Qwen/Qwen-7B",
)
parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
parser.add_argument("--gpu", type=int, default=0, help="gpu id")
1 year ago
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="Evaluation options")
group.add_argument("-d", "--eval_data_path", type=str, help="Path to eval data")
group.add_argument(
"--max-seq-len",
type=int,
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--debug", action="store_true", default=False, help="Print infos."
)
group.add_argument(
"--batch-size",
type=int,
default=1,
help="batch size",
)
1 year ago
args = parser.parse_args()
set_seed(args.seed)
main(args)
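
# Illustrative invocation (adjust the data path, checkpoint, and batch size to your setup):
#   python eval/evaluate_mmlu.py -d data/mmlu/data/ -c Qwen/Qwen-7B --batch-size 8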