Bug Description
When resuming training from a checkpoint using bf16 and the Muon optimizer, a RuntimeError occurs due to a dtype mismatch.
- Model parameters and gradients are in bf16.
- Optimizer state (momentum_buffer) is loaded from the checkpoint as fp32.
- The mismatch surfaces when Muon applies updates (e.g., lerp_) between the fp32 momentum buffers and the bf16 gradients (see the standalone sketch below).
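For context, the dtype restriction is independent of DeepSpeed: torch.Tensor.lerp_ requires end to have the same dtype as the tensor being updated in place. A minimal sketch of the mismatch in isolation (not part of the original repro):

import torch

momentum = torch.zeros(1, 64, dtype=torch.float32)   # state restored as fp32
grad = torch.randn(1, 64, dtype=torch.bfloat16)      # gradient produced in bf16

# Raises: expected dtype torch.float32 for `end`, but got dtype torch.bfloat16
momentum.lerp_(grad, 0.05)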
Minimal Reproducible Example
import torch
import os
import deepspeed
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
def train_step(model_engine, x, y):
output = model_engine(x)
loss = ((output - y) ** 2).mean()
model_engine.backward(loss)
model_engine.step()
hidden_size = 64
out_size = 1
dtype = torch.bfloat16
# Setup dummy data and model
x = torch.randn(16, hidden_size, dtype=dtype).cuda()
y = torch.randn(16, out_size, dtype=dtype).cuda()
model = torch.nn.Linear(hidden_size, out_size)
ds_config = {
"bf16": {"enabled": True},
"zero_optimization": {"stage": 2},
"optimizer": {"type": "Muon", "params": {"lr": 1e-3}},
"zero_allow_untested_optimizer": True,
"train_batch_size": 4,
"train_micro_batch_size_per_gpu": 1
}
model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
train_step(model_engine, x, y)
model_engine.save_checkpoint("./test_checkpoint")
model_engine.load_checkpoint("./test_checkpoint")
# Resume training -> trigger error
train_step(model_engine, x, y)
Run command:
deepspeed --include localhost:0,1,2,3 example.py
Error trace:
[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method lerp_(*(FakeTensor(..., device='cuda:2', size=(1, 64)), FakeTensor(..., device='cuda:2', size=(1, 64), dtype=torch.bfloat16), 0.050000000000000044), **{}):
[rank2]: expected dtype torch.float32 for `end`, but got dtype torch.bfloat16
[rank2]: from user code:
[rank2]: File "/usr/local/lib/python3.12/dist-packages/deepspeed/runtime/zero/muon/original_muon.py", line 72, in torch_dynamo_resume_in_muon_update_at_71
[rank2]: momentum.lerp_(grad, 1 - beta)
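The mismatch can be confirmed right after load_checkpoint() by printing the restored state dtypes. A small diagnostic sketch, assuming the optimizer object returned by deepspeed.initialize in the script above and the inner optimizer being reachable as optimizer.optimizer (as used in the workaround below):

# Compare each parameter's dtype with its restored momentum_buffer dtype
for param, state in optimizer.optimizer.state.items():
    buf = state.get("momentum_buffer")
    if buf is not None:
        print(tuple(param.shape), "param:", param.dtype, "buffer:", buf.dtype)
# On resume, the buffers come back as torch.float32 while the gradients are bf16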
Environment
- GPU: 4x H100
- Docker image: nvidia/pytorch:25.03-py3
- torch: 2.7.0a0+7c8ec84dab.nv25.3
- deepspeed: 0.18.3
My Workaround
Debugging showed that momentum_buffer is loaded as an fp32 tensor, which conflicts with bf16 gradients in muon_update. I found that manually converting the buffer to bf16 right after load_checkpoint() fixes the crash:
for tensor_key, values_dict in optimizer.optimizer.state.items():
    for key, tensor_value in values_dict.items():
        if key == "momentum_buffer":
            values_dict[key] = tensor_value.to(dtype=torch.bfloat16)
This resolves the error, but I'm unsure if forcing momentum_buffer to bf16 is numerically stable or the intended behavior, given that optimizer states are typically kept in fp32. Opening this issue to find the proper fix.
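A slightly more general variant of the same workaround, a sketch only: instead of hardcoding bf16, cast each restored momentum_buffer to the dtype of its parameter (assuming, as in the snippet above, that the inner optimizer's state is keyed by the parameter tensors):

for param, state in optimizer.optimizer.state.items():
    buf = state.get("momentum_buffer")
    if buf is not None and buf.dtype != param.dtype:
        # Match the buffer to whatever precision the parameters/gradients use
        state["momentum_buffer"] = buf.to(dtype=param.dtype)

The other direction would be for muon_update to cast the incoming gradient to the buffer's dtype, which would keep the momentum state in fp32; I don't know which behavior the Muon integration intends.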