@@ -278,7 +278,7 @@ def train():
     local_rank = training_args.local_rank

-    device_map = None
+    device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     if lora_args.q_lora:
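A minimal sketch of the device-placement logic this hunk feeds into, assuming the usual Q-LoRA pattern: resolve_device_map is a hypothetical helper, and the DDP branch mirrors the common "pin to local rank" idiom rather than lines shown in this diff.

    # Hypothetical helper illustrating the placement rule the patched default sets up.
    import os

    def resolve_device_map(q_lora: bool):
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        ddp = world_size != 1      # more than one process => distributed run
        device_map = "auto"        # the hunk's new default (previously None)
        if q_lora and ddp:
            # Pin the whole model to this rank's GPU so every DDP process
            # holds its own quantized replica instead of sharding across GPUs.
            device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        return device_map

The design point of the change itself: with "auto" as the default instead of None, a single-process run lets accelerate place (and if needed shard) the model automatically, while distributed Q-LoRA still overrides the map per rank.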
@@ -302,6 +302,7 @@ def train():
         config=config,
         cache_dir=training_args.cache_dir,
         device_map=device_map,
+        low_cpu_mem_usage=True if training_args.use_lora and not lora_args.q_lora else False,
         trust_remote_code=True,
         quantization_config=GPTQConfig(
             bits=4, disable_exllama=True
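Once both hunks apply, the call site reads roughly as below. This is a self-contained sketch under stated assumptions: the checkpoint name and the two boolean stand-ins are hypothetical, and the config/cache_dir kwargs from the diff are omitted for brevity.

    # Self-contained sketch of the patched AutoModelForCausalLM.from_pretrained call.
    import transformers
    from transformers import GPTQConfig

    use_lora, q_lora = True, True   # stand-ins for training_args.use_lora / lora_args.q_lora
    device_map = "auto"             # the new default from the first hunk

    model = transformers.AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat-Int4",   # hypothetical checkpoint, not from the diff
        device_map=device_map,
        # The line the second hunk adds: low-CPU-memory loading only for plain LoRA.
        low_cpu_mem_usage=use_lora and not q_lora,
        trust_remote_code=True,
        # Quantize on load only when actually training with Q-LoRA.
        quantization_config=GPTQConfig(bits=4, disable_exllama=True)
        if use_lora and q_lora
        else None,
    )

Note the added line as written in the diff was missing its else branch (a SyntaxError in Python's conditional expression); the reconstruction above spells it out as "else False", and the sketch uses the equivalent plain boolean expression.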