
fix: resolve OOM in long-sequence training via conditional entropy gradient tracking #1524

Open
ppraneth wants to merge 1 commit into THUDM:main from ppraneth:oom

Conversation

@ppraneth
Contributor

This PR addresses the CUDA out-of-memory (OOM) issue reported in #1523, which occurs during training with long sequences (e.g., >30k tokens), by implementing conditional gradient tracking for the entropy term.

The Problem

In the previous implementation, the entropy calculation was always differentiable, regardless of the entropy_coef value. To support the backward pass, the system was forced to store massive intermediate activation tensors (logits and softmax outputs) with shapes of (seq_len, vocab_size / TP). For long sequences and large vocabularies, these tensors consumed 5–6 GB of VRAM per sample on the last pipeline stage, leading to OOM even when the entropy contribution to the loss was zero.
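As a rough illustration of the scale involved, the following back-of-the-envelope calculation uses assumed numbers (the sequence length, vocabulary size, TP degree, and dtype are guesses, not values taken from this PR) and lands in the same range:

```python
# Back-of-the-envelope estimate with assumed numbers (not taken from the PR):
# seq_len ~ 32k tokens, vocab ~ 152k split across TP = 4 ranks, bf16 activations.
seq_len = 32_768
vocab_per_rank = 152_064 // 4          # vocab_size / TP
bytes_per_elem = 2                     # bf16
tensor_gb = seq_len * vocab_per_rank * bytes_per_elem / 1024**3
print(f"one saved activation: {tensor_gb:.2f} GB")                     # ~2.3 GB
print(f"logits + softmax kept for backward: {2 * tensor_gb:.2f} GB")   # ~4.6 GB
```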

The Solution

I have decoupled entropy for monitoring (logging) from entropy for training (loss).

  • Conditional Logic: Gradients for the entropy term are now only tracked if args.entropy_coef > 0.
  • Memory Recovery: When entropy_coef is 0.0, entropy for logging is computed within a torch.no_grad() context. This prevents PyTorch from allocating memory for backward tensors, effectively reclaiming several gigabytes of VRAM per sequence (see the sketch after this list).
  • Backend Support: This fix has been applied to both the Megatron and FSDP backends to ensure consistency across configurations.
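A minimal sketch of the conditional pattern described above; the function and argument names (`log_probs_and_entropy`, `requires_entropy_grad`) are illustrative rather than the exact identifiers in the repository:

```python
import contextlib

import torch
import torch.nn.functional as F


def log_probs_and_entropy(logits, target_ids, requires_entropy_grad: bool):
    """Always return differentiable token log-probs; make the entropy
    differentiable only when it actually contributes to the loss."""
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

    # When entropy_coef == 0, entropy is computed purely for logging inside
    # no_grad, so no extra activations are retained for the backward pass.
    ctx = contextlib.nullcontext() if requires_entropy_grad else torch.no_grad()
    with ctx:
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(dim=-1)

    return token_log_probs, entropy
```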

Files Changed

  • slime/utils/ppo_utils.py: Added requires_entropy_grad flag and no_grad context to calculate_log_probs_and_entropy.
  • slime/backends/megatron_utils/loss.py: Passed gradient requirement flag based on entropy_coef (the caller-side pattern is sketched after this list).
  • slime/backends/fsdp_utils/actor.py: Implemented conditional tracking and added contextlib import.
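The caller side can be sketched as below, reusing the illustrative helper from the previous snippet. The surrounding names (actor_loss, args, metrics, the clipped PPO surrogate) are placeholders, not the actual slime loss code:

```python
def actor_loss(args, logits, target_ids, old_log_probs, advantages, metrics):
    # Gradient requirement for entropy is derived from the config, as in the PR.
    requires_entropy_grad = args.entropy_coef > 0

    log_probs, entropy = log_probs_and_entropy(
        logits, target_ids, requires_entropy_grad=requires_entropy_grad
    )

    # Standard PPO-style clipped surrogate (simplified placeholder).
    ratio = (log_probs - old_log_probs).exp()
    policy_loss = -torch.min(
        ratio * advantages,
        ratio.clamp(1 - 0.2, 1 + 0.2) * advantages,
    ).mean()

    loss = policy_loss
    if requires_entropy_grad:
        # The entropy bonus joins the backward graph only when its
        # coefficient is non-zero.
        loss = loss - args.entropy_coef * entropy.mean()

    # Entropy is logged either way; when computed under no_grad it is
    # already detached.
    metrics["entropy"] = entropy.mean().item()
    return loss
```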

@lilei199908
Collaborator

Thanks, we are refactoring this part for better memory use.

@ppraneth
Contributor Author

@lilei199908 could you review this PR and suggest any edits? I believe this should reduce memory usage.

@lilei199908
Collaborator

@lilei199908 could you review this PR and suggest any edits? I believe this should reduce memory usage.

LGTM, we will merge it soon! Thanks.

@ppraneth
Contributor Author

ppraneth commented Feb 3, 2026

@zhuzilin Can you check and merge this PR?

