specify repetition penalty

main
yangapku 1 year ago
parent 024146bc82
commit f076e2fa42

@ -31,6 +31,7 @@ def load_models_tokenizer(args):
args.checkpoint_path, trust_remote_code=True args.checkpoint_path, trust_remote_code=True
) )
model.generation_config.do_sample = False # use greedy decoding model.generation_config.do_sample = False # use greedy decoding
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
return model, tokenizer return model, tokenizer
def process_before_extraction(gen, question, choice_dict): def process_before_extraction(gen, question, choice_dict):

@ -129,6 +129,7 @@ if __name__ == "__main__":
args.checkpoint_path, trust_remote_code=True args.checkpoint_path, trust_remote_code=True
) )
model.generation_config.do_sample = False # use greedy decoding model.generation_config.do_sample = False # use greedy decoding
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
test = dataset["test"] test = dataset["test"]

@ -93,6 +93,7 @@ if __name__ == "__main__":
args.checkpoint_path, trust_remote_code=True args.checkpoint_path, trust_remote_code=True
) )
model.generation_config.do_sample = False # use greedy decoding model.generation_config.do_sample = False # use greedy decoding
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8")) f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))

@ -35,6 +35,7 @@ def load_models_tokenizer(args):
args.checkpoint_path, trust_remote_code=True args.checkpoint_path, trust_remote_code=True
) )
model.generation_config.do_sample = False # use greedy decoding model.generation_config.do_sample = False # use greedy decoding
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
return model, tokenizer return model, tokenizer

Loading…
Cancel
Save