@@ -291,9 +291,9 @@ def train():
     ):
         raise RuntimeError("ZeRO3 is incompatible with LoRA when finetuning on base model.")
 
-    model_load_kwargs = {}
-    if deepspeed.is_deepspeed_zero3_enabled():
-        model_load_kwargs['low_cpu_mem_usage'] = False
+    model_load_kwargs = {
+        'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(),
+    }
 
     # Set RoPE scaling factor
     config = transformers.AutoConfig.from_pretrained(
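
For reference, `low_cpu_mem_usage=True` conflicts with DeepSpeed ZeRO-3 (which partitions weights while `from_pretrained` runs), so the new one-liner keeps the flag on only when ZeRO-3 is off, matching what the old if-branch did. Below is a minimal sketch of how such kwargs are typically consumed; it assumes the script's `deepspeed` name refers to the `transformers.integrations` module and uses a placeholder checkpoint, neither of which appears in this hunk.

from transformers import AutoModelForCausalLM
from transformers.integrations import deepspeed  # assumption: matches the script's import style

# Mirror the new hunk: enable the low-memory loading path only outside ZeRO-3,
# since ZeRO-3 handles weight placement itself and rejects low_cpu_mem_usage=True.
model_load_kwargs = {
    'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(),
}

# Hypothetical downstream call; the checkpoint name is illustrative only.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B",
    trust_remote_code=True,
    **model_load_kwargs,
)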