diff --git a/finetune.py b/finetune.py index a4fee87..ccd5f70 100644 --- a/finetune.py +++ b/finetune.py @@ -278,7 +278,7 @@ def train(): local_rank = training_args.local_rank - device_map = None + device_map = "auto" world_size = int(os.environ.get("WORLD_SIZE", 1)) ddp = world_size != 1 if lora_args.q_lora: @@ -302,6 +302,7 @@ def train(): config=config, cache_dir=training_args.cache_dir, device_map=device_map, + low_cpu_mem_usage=True, trust_remote_code=True, quantization_config=GPTQConfig( bits=4, disable_exllama=True