|
|
@ -93,6 +93,7 @@ if __name__ == "__main__":
|
|
|
|
args.checkpoint_path, trust_remote_code=True
|
|
|
|
args.checkpoint_path, trust_remote_code=True
|
|
|
|
)
|
|
|
|
)
|
|
|
|
model.generation_config.do_sample = False # use greedy decoding
|
|
|
|
model.generation_config.do_sample = False # use greedy decoding
|
|
|
|
|
|
|
|
model.generation_config.repetition_penalty = 1.0 # disable repetition penalty
|
|
|
|
|
|
|
|
|
|
|
|
f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))
|
|
|
|
f_output = jsonlines.Writer(open(args.sample_output_file, "w", encoding="utf-8"))
|
|
|
|
|
|
|
|
|
|
|
|