|
|
@ -270,12 +270,6 @@ def train():
|
|
|
|
if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1:
|
|
|
|
if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1:
|
|
|
|
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
|
|
|
|
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
|
|
|
|
|
|
|
|
|
|
|
|
compute_dtype = (
|
|
|
|
|
|
|
|
torch.float16
|
|
|
|
|
|
|
|
if training_args.fp16
|
|
|
|
|
|
|
|
else (torch.bfloat16 if training_args.bf16 else torch.float32)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local_rank = training_args.local_rank
|
|
|
|
local_rank = training_args.local_rank
|
|
|
|
|
|
|
|
|
|
|
|
device_map = "auto"
|
|
|
|
device_map = "auto"
|
|
|
|