import os import argparse import re import torch import pandas as pd from tqdm import tqdm from thefuzz import process from transformers.trainer_utils import set_seed from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig ''' wget https://people.eecs.berkeley.edu/~hendrycks/data.tar mkdir data/mmlu mv data.tar data/mmlu cd data/mmlu; tar xf data.tar cd ../../ pip install thefuzz python eval/evaluate_chat_mmlu.py -d data/mmlu/data/ ''' def load_models_tokenizer(args): tokenizer = AutoTokenizer.from_pretrained( args.checkpoint_path, trust_remote_code=True ) model = AutoModelForCausalLM.from_pretrained( args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True, ).eval() model.generation_config = GenerationConfig.from_pretrained( args.checkpoint_path, trust_remote_code=True ) model.generation_config.do_sample = False # use greedy decoding model.generation_config.repetition_penalty = 1.0 # disable repetition penalty return model, tokenizer def format_example(line): example = ( "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n" + line["question"] + "\n" ) for choice in choices: example += f'{choice}. {line[f"{choice}"]}\n' return example def process_before_extraction(gen, choice_dict): # replace the choice by letter in the generated sentence # from longest one to shortest one for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True): pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE) gen = pattern.sub(key, gen) return gen def extract_choice(gen, choice_list): # answer is A | choice is A | choose A res = re.search( r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen, ) # A is correct | A is right if res is None: res = re.search( r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen, ) # straight answer: A if res is None: res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen) # simply extract the first appearred letter if res is None: res = re.search(r"(?