diff --git a/finetune.py b/finetune.py
index dd7df12..4a4e334 100644
--- a/finetune.py
+++ b/finetune.py
@@ -291,9 +291,9 @@ def train():
     ):
         raise RuntimeError("ZeRO3 is incompatible with LoRA when finetuning on base model.")
 
-    model_load_kwargs = {}
-    if deepspeed.is_deepspeed_zero3_enabled():
-        model_load_kwargs['low_cpu_mem_usage'] = False
+    model_load_kwargs = {
+        'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(),
+    }
 
     # Set RoPE scaling factor
     config = transformers.AutoConfig.from_pretrained(
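
A minimal sketch (not part of the patch) of what this hunk changes semantically: before, 'low_cpu_mem_usage' was only passed when ZeRO3 was enabled; after, the key is always present and is explicitly True whenever ZeRO3 is off. Here `zero3_enabled` stands in for `deepspeed.is_deepspeed_zero3_enabled()`, and both helper functions are hypothetical, written only to compare the two behaviors.

def old_kwargs(zero3_enabled: bool) -> dict:
    # Before the patch: the key is only set when ZeRO3 is enabled.
    model_load_kwargs = {}
    if zero3_enabled:
        model_load_kwargs['low_cpu_mem_usage'] = False
    return model_load_kwargs

def new_kwargs(zero3_enabled: bool) -> dict:
    # After the patch: the key is always present, enabled exactly when ZeRO3 is off.
    return {'low_cpu_mem_usage': not zero3_enabled}

# Identical under ZeRO3 ...
assert old_kwargs(True) == new_kwargs(True) == {'low_cpu_mem_usage': False}
# ... but without ZeRO3 the patch now opts in to low_cpu_mem_usage
# instead of leaving from_pretrained's default in place.
assert old_kwargs(False) == {}
assert new_kwargs(False) == {'low_cpu_mem_usage': True}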