Summary

Refactor the FlagOS optimizer and multi_tensor implementations to better match CUDA behavior and improve code quality.

Changes

fused_adam.py (FlagOS backend)

  • Remove unused inv_scale and out_dtype parameters from multi_tensor_adam_fl
  • multi_tensor_adam_param_remainder_fl: rewrite the FP32 master weight reconstruction using bit manipulation (int16 high/low bits) to match the CUDA implementation exactly (see the sketch below)
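
For reference, a minimal PyTorch sketch of this kind of bit-level reconstruction, assuming the FP32 master weight is split into a BF16 parameter (high 16 bits) and an int16 remainder (low 16 bits); the helper name and signature are illustrative, not the actual multi_tensor_adam_param_remainder_fl code:

```python
import torch

def reconstruct_fp32_master_weight(param_bf16: torch.Tensor,
                                    remainder_int16: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: reassemble the 32-bit pattern of the FP32 master weight
    # from its BF16 high half and int16 low half.
    high = param_bf16.view(torch.int16).to(torch.int32) << 16   # upper 16 bits of the FP32 value
    low = remainder_int16.to(torch.int32) & 0xFFFF              # lower 16 bits, masked to avoid sign extension
    return (high | low).view(torch.float32)
```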

multi_tensor.py (FlagOS backend)

  • multi_tensor_l2_norm_fl: add type hints, a noop_flag check, and inf/nan detection; replace the raw ** and + operators with flag_gems.mul / flag_gems.add
  • multi_tensor_scale_fl: add type hints, a noop_flag check, and inf/nan detection; replace src * scale with flag_gems.mul(src, scale) (see the sketch below)
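
A rough sketch of what these checks could look like; it assumes flag_gems.mul / flag_gems.add mirror torch.mul / torch.add semantics (including flag_gems.mul(src, scale) with a Python scalar, as quoted above), and the function names and list layout are illustrative:

```python
import torch
import flag_gems

def l2_norm_sketch(noop_flag: torch.Tensor, tensors: list[torch.Tensor]) -> torch.Tensor:
    # Illustrative only: global L2 norm over a tensor list with the checks described above.
    if noop_flag.item() != 0:                      # honour the noop flag before doing any work
        return torch.zeros((), device=tensors[0].device)
    total = torch.zeros((), dtype=torch.float32, device=tensors[0].device)
    for t in tensors:
        sq = flag_gems.mul(t, t)                   # replaces the raw `t ** 2`
        total = flag_gems.add(total, sq.sum())     # replaces the raw `+` accumulation
    if not torch.isfinite(total):                  # inf/nan detection: raise the noop flag
        noop_flag.fill_(1)
    return total.sqrt()

def scale_sketch(noop_flag: torch.Tensor,
                 srcs: list[torch.Tensor],
                 dsts: list[torch.Tensor],
                 scale: float) -> None:
    # Illustrative only: dst = src * scale with the same noop / inf/nan handling.
    if noop_flag.item() != 0:
        return
    for src, dst in zip(srcs, dsts):
        scaled = flag_gems.mul(src, scale)         # the PR's flag_gems.mul(src, scale)
        if not torch.isfinite(scaled).all():
            noop_flag.fill_(1)
        dst.copy_(scaled)
```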

optimizer.py (reference backend)

  • Update multi_tensor_l2norm_torch and multi_tensor_adam_torch to match the new signatures and CUDA behavior (L2 vs. AdamW mode split; see the sketch after this list)
  • Rewrite multi_tensor_adam_param_remainder_torch with bit manipulation matching CUDA
  • Rename eps to epsilon for consistency
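
A hedged single-tensor sketch of the L2 vs. AdamW split mentioned above; argument names (including the renamed epsilon) are illustrative and may not match the actual multi_tensor_adam_torch signature:

```python
import torch

def adam_step_sketch(param: torch.Tensor, grad: torch.Tensor,
                     exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, *,
                     lr: float, beta1: float, beta2: float, epsilon: float,
                     weight_decay: float, step: int, adam_w_mode: bool) -> None:
    # Illustrative single-tensor update showing the two weight-decay modes.
    if not adam_w_mode and weight_decay != 0:
        grad = grad + weight_decay * param               # L2 mode: fold decay into the gradient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(epsilon)
    if adam_w_mode and weight_decay != 0:
        param.mul_(1 - lr * weight_decay)                # AdamW mode: decoupled decay on the parameter
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```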

optimizers/__init__.py

  • Export multi_tensor_scale and multi_tensor_l2norm
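
A possible shape for the export, assuming the helpers live in a sibling multi_tensor module (the import path is an assumption):

```python
# Hypothetical layout for optimizers/__init__.py; `.multi_tensor` is an assumed module path.
from .multi_tensor import multi_tensor_l2norm, multi_tensor_scale

# Extend the existing __all__ rather than replacing it.
__all__ = ["multi_tensor_l2norm", "multi_tensor_scale"]
```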

Misc

  • Fix missing newline at end of files
