From c00209f932c38c45ac885a4db7c2429885d93f5c Mon Sep 17 00:00:00 2001 From: yangapku Date: Mon, 30 Oct 2023 19:13:14 +0800 Subject: [PATCH] update evaluate scripts --- eval/evaluate_ceval.py | 84 +++++++++++++++++++++++++----------------- eval/evaluate_cmmlu.py | 84 +++++++++++++++++++++++++----------------- eval/evaluate_mmlu.py | 84 +++++++++++++++++++++++++----------------- 3 files changed, 153 insertions(+), 99 deletions(-) diff --git a/eval/evaluate_ceval.py b/eval/evaluate_ceval.py index a6618cf..1c1072e 100644 --- a/eval/evaluate_ceval.py +++ b/eval/evaluate_ceval.py @@ -20,13 +20,22 @@ python evaluate_ceval.py -d data/ceval/ def load_models_tokenizer(args): tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token='<|extra_0|>', + eos_token='<|endoftext|>', + padding_side='left', + trust_remote_code=True ) model = AutoModelForCausalLM.from_pretrained( - args.checkpoint_path, device_map="auto", trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + device_map="auto", + trust_remote_code=True ).eval() model.generation_config = GenerationConfig.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + trust_remote_code=True ) return model, tokenizer @@ -56,11 +65,12 @@ def generate_few_shot_prompt(k, subject, dev_df): def get_logits(tokenizer, model, inputs: List[str]): - input_ids = tokenizer(inputs, padding=False)["input_ids"] + input_ids = tokenizer(inputs, padding='longest')["input_ids"] input_ids = torch.tensor(input_ids, device=model.device) tokens = {"input_ids": input_ids} + attention_mask = input_ids.ne(tokenizer.pad_token_id) - outputs = model(input_ids)["logits"] + outputs = model(input_ids, attention_mask=attention_mask)["logits"] logits = outputs[:, -1, :] log_probs = torch.nn.functional.softmax(logits, dim=-1) return log_probs, {"tokens": tokens} @@ -76,6 +86,7 @@ def eval_subject( dev_df=None, few_shot=False, save_result_dir=None, + batch_size=1, **kwargs, ): result = [] @@ -88,39 +99,39 @@ def eval_subject( if args.debug: print(f"few_shot_prompt: {few_shot_prompt}") - for _, row in tqdm(test_df.iterrows(), total=len(test_df)): - question = format_example(row, include_answer=False) - full_prompt = few_shot_prompt + question - - output, input_info = get_logits(tokenizer, model, [full_prompt]) - assert output.shape[0] == 1 - logits = output.flatten() - - softval = torch.nn.functional.softmax( - torch.tensor( - [ - logits[tokenizer("A")["input_ids"]], - logits[tokenizer("B")["input_ids"]], - logits[tokenizer("C")["input_ids"]], - logits[tokenizer("D")["input_ids"]], - ] - ), - dim=0, - ) + choices_ids = torch.tensor( + tokenizer("A")["input_ids"] + tokenizer("B")["input_ids"] + + tokenizer("C")["input_ids"] + tokenizer("D")["input_ids"] + ).unsqueeze(0).to(model.device) + + idx_list = list(range(0, len(test_df), batch_size)) + for i in tqdm(idx_list): + full_prompt_list = [] + answer_list = [] + for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'): + question = format_example(row, include_answer=False) + full_prompt = few_shot_prompt + question + full_prompt_list.append(full_prompt) + if 'answer' in row: + answer_list.append(row['answer']) + + logits, input_info = get_logits(tokenizer, model, full_prompt_list) + softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1) if softval.dtype in {torch.bfloat16, torch.float16}: softval = softval.to(dtype=torch.float32) probs = softval.detach().cpu().numpy() - for i, choice in enumerate(choices): - all_probs[f"prob_{choice}"].append(probs[i]) - pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] + for i in range(len(probs)): + for j, choice in enumerate(choices): + all_probs[f"prob_{choice}"].append(probs[i][j]) + pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])] - if "answer" in row: - correct = 1 if pred == row["answer"] else 0 - score.append(correct) - if args.debug: - print(f'{question} pred: {pred} ref: {row["answer"]}') - result.append(pred) + if answer_list != []: + correct = 1 if pred == answer_list[i] else 0 + score.append(correct) + if args.debug: + print(f'{question} pred: {pred} ref: {answer_list[i]}') + result.append(pred) if score: correct_ratio = 100 * sum(score) / len(score) @@ -395,6 +406,7 @@ def main(args): k=5, few_shot=True, save_result_dir=f"outs/ceval_eval_result", + batch_size=args.batch_size ) dev_result[subject_name] = score cal_ceval(dev_result) @@ -425,6 +437,12 @@ if __name__ == "__main__": group.add_argument( "--debug", action="store_true", default=False, help="Print infos." ) + group.add_argument( + "--batch-size", + type=int, + default=1, + help="batch size", + ) args = parser.parse_args() set_seed(args.seed) diff --git a/eval/evaluate_cmmlu.py b/eval/evaluate_cmmlu.py index 2d2371d..cdd5888 100644 --- a/eval/evaluate_cmmlu.py +++ b/eval/evaluate_cmmlu.py @@ -26,13 +26,22 @@ def load_models_tokenizer(args): from transformers.generation import GenerationConfig tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token='<|extra_0|>', + eos_token='<|endoftext|>', + padding_side='left', + trust_remote_code=True ) model = AutoModelForCausalLM.from_pretrained( - args.checkpoint_path, device_map="auto", trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + device_map="auto", + trust_remote_code=True ).eval() model.generation_config = GenerationConfig.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + trust_remote_code=True ) return model, tokenizer @@ -62,11 +71,12 @@ def generate_few_shot_prompt(k, subject, dev_df): def get_logits(tokenizer, model, inputs: List[str]): - input_ids = tokenizer(inputs, padding=False)["input_ids"] + input_ids = tokenizer(inputs, padding='longest')["input_ids"] input_ids = torch.tensor(input_ids, device=model.device) tokens = {"input_ids": input_ids} + attention_mask = input_ids.ne(tokenizer.pad_token_id) - outputs = model(input_ids)["logits"] + outputs = model(input_ids, attention_mask=attention_mask)["logits"] logits = outputs[:, -1, :] log_probs = torch.nn.functional.softmax(logits, dim=-1) return log_probs, {"tokens": tokens} @@ -82,6 +92,7 @@ def eval_subject( dev_df=None, few_shot=False, save_result_dir=None, + batch_size=1, **kwargs, ): result = [] @@ -94,39 +105,39 @@ def eval_subject( if args.debug: print(f"few_shot_prompt: {few_shot_prompt}") - for _, row in tqdm(test_df.iterrows(), total=len(test_df)): - question = format_example(row, include_answer=False) - full_prompt = few_shot_prompt + question - - output, input_info = get_logits(tokenizer, model, [full_prompt]) - assert output.shape[0] == 1 - logits = output.flatten() - - softval = torch.nn.functional.softmax( - torch.tensor( - [ - logits[tokenizer("A")["input_ids"]], - logits[tokenizer("B")["input_ids"]], - logits[tokenizer("C")["input_ids"]], - logits[tokenizer("D")["input_ids"]], - ] - ), - dim=0, - ) + choices_ids = torch.tensor( + tokenizer("A")["input_ids"] + tokenizer("B")["input_ids"] + + tokenizer("C")["input_ids"] + tokenizer("D")["input_ids"] + ).unsqueeze(0).to(model.device) + + idx_list = list(range(0, len(test_df), batch_size)) + for i in tqdm(idx_list): + full_prompt_list = [] + answer_list = [] + for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'): + question = format_example(row, include_answer=False) + full_prompt = few_shot_prompt + question + full_prompt_list.append(full_prompt) + if 'Answer' in row: + answer_list.append(row['Answer']) + + logits, input_info = get_logits(tokenizer, model, full_prompt_list) + softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1) if softval.dtype in {torch.bfloat16, torch.float16}: softval = softval.to(dtype=torch.float32) probs = softval.detach().cpu().numpy() - for i, choice in enumerate(choices): - all_probs[f"prob_{choice}"].append(probs[i]) - pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] + for i in range(len(probs)): + for j, choice in enumerate(choices): + all_probs[f"prob_{choice}"].append(probs[i][j]) + pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])] - if "Answer" in row: - correct = 1 if pred == row["Answer"] else 0 - score.append(correct) - if args.debug: - print(f'{question} pred: {pred} ref: {row["Answer"]}') - result.append(pred) + if answer_list != []: + correct = 1 if pred == answer_list[i] else 0 + score.append(correct) + if args.debug: + print(f'{question} pred: {pred} ref: {answer_list[i]}') + result.append(pred) if score: correct_ratio = 100 * sum(score) / len(score) @@ -288,6 +299,7 @@ def main(args): k=5, few_shot=True, save_result_dir=f"outs/cmmlu_eval_result", + batch_size=args.batch_size ) test_result[subject_name] = score cal_cmmlu(test_result) @@ -318,6 +330,12 @@ if __name__ == "__main__": group.add_argument( "--debug", action="store_true", default=False, help="Print infos." ) + group.add_argument( + "--batch-size", + type=int, + default=1, + help="batch size", + ) args = parser.parse_args() set_seed(args.seed) diff --git a/eval/evaluate_mmlu.py b/eval/evaluate_mmlu.py index 2843434..9347dce 100644 --- a/eval/evaluate_mmlu.py +++ b/eval/evaluate_mmlu.py @@ -21,13 +21,22 @@ python eval/evaluate_mmlu.py -d data/mmlu/data/ def load_models_tokenizer(args): tokenizer = AutoTokenizer.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token='<|extra_0|>', + eos_token='<|endoftext|>', + padding_side='left', + trust_remote_code=True ) model = AutoModelForCausalLM.from_pretrained( - args.checkpoint_path, device_map="auto", trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + device_map="auto", + trust_remote_code=True ).eval() model.generation_config = GenerationConfig.from_pretrained( - args.checkpoint_path, trust_remote_code=True + args.checkpoint_path, + pad_token_id=tokenizer.pad_token_id, + trust_remote_code=True ) return model, tokenizer @@ -67,14 +76,15 @@ def generate_few_shot_prompt(k, subject, dev_df): def get_logits(tokenizer, model, inputs: List[str]): - input_ids = tokenizer(inputs, padding=False)["input_ids"] + input_ids = tokenizer(inputs, padding='longest')["input_ids"] input_ids = torch.tensor(input_ids, device=model.device) if input_ids.shape[1] > args.max_seq_len: input_ids = input_ids[:, input_ids.shape[1] - args.max_seq_len + 1 :] tokens = {"input_ids": input_ids} + attention_mask = input_ids.ne(tokenizer.pad_token_id) - outputs = model(input_ids)["logits"] + outputs = model(input_ids, attention_mask=attention_mask)["logits"] logits = outputs[:, -1, :] log_probs = torch.nn.functional.softmax(logits, dim=-1) return log_probs, {"tokens": tokens} @@ -90,6 +100,7 @@ def eval_subject( dev_df=None, few_shot=False, save_result_dir=None, + batch_size=1, **kwargs, ): result = [] @@ -102,39 +113,39 @@ def eval_subject( if args.debug: print(f"few_shot_prompt: {few_shot_prompt}") - for _, row in tqdm(test_df.iterrows(), total=len(test_df)): - question = format_example(row, include_answer=False) - full_prompt = few_shot_prompt + question - - output, input_info = get_logits(tokenizer, model, [full_prompt]) - assert output.shape[0] == 1 - logits = output.flatten() - - softval = torch.nn.functional.softmax( - torch.tensor( - [ - logits[tokenizer(" A")["input_ids"]], - logits[tokenizer(" B")["input_ids"]], - logits[tokenizer(" C")["input_ids"]], - logits[tokenizer(" D")["input_ids"]], - ] - ), - dim=0, - ) + choices_ids = torch.tensor( + tokenizer(" A")["input_ids"] + tokenizer(" B")["input_ids"] + + tokenizer(" C")["input_ids"] + tokenizer(" D")["input_ids"] + ).unsqueeze(0).to(model.device) + + idx_list = list(range(0, len(test_df), batch_size)) + for i in tqdm(idx_list): + full_prompt_list = [] + answer_list = [] + for row in test_df.iloc[i:i+batch_size].to_dict(orient='records'): + question = format_example(row, include_answer=False) + full_prompt = few_shot_prompt + question + full_prompt_list.append(full_prompt) + if 'answer' in row: + answer_list.append(row['answer']) + + logits, input_info = get_logits(tokenizer, model, full_prompt_list) + softval = logits.gather(1, choices_ids.expand(logits.size(0), -1)).softmax(1) if softval.dtype in {torch.bfloat16, torch.float16}: softval = softval.to(dtype=torch.float32) probs = softval.detach().cpu().numpy() - for i, choice in enumerate(choices): - all_probs[f"prob_{choice}"].append(probs[i]) - pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] + for i in range(len(probs)): + for j, choice in enumerate(choices): + all_probs[f"prob_{choice}"].append(probs[i][j]) + pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs[i])] - if "answer" in row: - correct = 1 if pred == row["answer"] else 0 - score.append(correct) - if args.debug: - print(f'{question} pred: {pred} ref: {row["answer"]}') - result.append(pred) + if answer_list != []: + correct = 1 if pred == answer_list[i] else 0 + score.append(correct) + if args.debug: + print(f'{question} pred: {pred} ref: {answer_list[i]}') + result.append(pred) if save_result_dir: test_df["model_output"] = result @@ -209,6 +220,7 @@ def main(args): k=5, few_shot=True, save_result_dir=f"outs/mmlu_eval_result", + batch_size=args.batch_size ) dev_result[subject_name] = score cal_mmlu(dev_result) @@ -308,6 +320,12 @@ if __name__ == "__main__": group.add_argument( "--debug", action="store_true", default=False, help="Print infos." ) + group.add_argument( + "--batch-size", + type=int, + default=1, + help="batch size", + ) args = parser.parse_args() set_seed(args.seed)