
Commit 09885ef

Fix BF16_Optimizer being used without ZeRO (#7790)
The test_ds_initialize.py::TestOptimizerImplementation expects the configuration (None, 'bf16', 'fp32') to be unimplemented. However, it is actually supported by DeepSpeed and uses FP16_Optimizer in bf16 mode. This PR adds the configuration to the dict of supported combinations.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent 1393f75 commit 09885ef
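For context, a minimal sketch (not part of this commit) of a DeepSpeed setup that exercises the (None, 'bf16', 'fp32') combination described above, i.e. no ZeRO, bf16 model dtype, and fp32 gradient accumulation. The model, optimizer, and batch-size values are placeholders, and the data_types.grad_accum_dtype key is assumed to be the config knob behind the test's 'fp32' parameter.

import torch
import deepspeed

# Hypothetical minimal config: bf16 compute, fp32 gradient accumulation, no ZeRO block.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "data_types": {"grad_accum_dtype": "fp32"},  # assumed knob for the grad-accum dtype
    # no "zero_optimization" entry -> ZeRO disabled (stage 0)
}

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# deepspeed.initialize routes optimizer setup through _do_optimizer_sanity_check,
# the method touched by the diff below, to pick an optimizer wrapper.
engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)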

File tree

1 file changed: +0 -2 lines changed


deepspeed/runtime/engine.py

Lines changed: 0 additions & 2 deletions
@@ -1481,8 +1481,6 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
                     )
                     return BFLOAT16
             return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16
-        elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
-            return BFLOAT16
         else:
             raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}")
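For readability, a simplified, standalone sketch of the branch structure this hunk edits; the outer condition and the return constants are stand-ins, not the full _do_optimizer_sanity_check implementation. Per the commit title, the removed elif appears to be the path that let the BF16 optimizer wrapper be selected for a bf16 model with fp32 gradient accumulation even without ZeRO; with it gone, that mix falls through to the NotImplementedError branch on this path.

import torch

# Placeholder tags standing in for DeepSpeed's internal optimizer-wrapper constants.
FP16, BFLOAT16, DDP_BFLOAT16 = "fp16_optimizer", "bf16_optimizer", "ddp_bf16"

def select_wrapper(model_dtype, grad_accum_dtype):
    # Stand-in for the (not shown) condition guarding the context lines above.
    if model_dtype == grad_accum_dtype:
        return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16
    # Removed by this commit:
    # elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
    #     return BFLOAT16
    else:
        raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}")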
