fix bug for ceval

main
qinxy3 2 years ago
parent 460ea3418b
commit bff91b3305

@ -80,7 +80,7 @@ def eval_subject(
score = [] score = []
few_shot_prompt = generate_few_shot_prompt( few_shot_prompt = generate_few_shot_prompt(
k, subject_name, dev_df) if few_shot else [] k, subject_name, dev_df) if few_shot else ''
all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []} all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
if args.debug: print(f"few_shot_prompt: {few_shot_prompt}") if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
@ -95,10 +95,10 @@ def eval_subject(
softval = torch.nn.functional.softmax( softval = torch.nn.functional.softmax(
torch.tensor( torch.tensor(
[ [
logits[tokenizer("A")['input_ids']], logits[tokenizer("A")['input_ids'][-1]],
logits[tokenizer("B")['input_ids']], logits[tokenizer("B")['input_ids'][-1]],
logits[tokenizer("C")['input_ids']], logits[tokenizer("C")['input_ids'][-1]],
logits[tokenizer("D")['input_ids']], logits[tokenizer("D")['input_ids'][-1]],
] ]
), ),
dim=0, dim=0,

Loading…
Cancel
Save