Refactor optimizer implementations and improve multi_tensor ops #36
+254
−144
Summary
Refactor the FlagOS optimizer and multi_tensor implementations to better match CUDA behavior and improve code quality.
Changes
- `fused_adam.py` (FlagOS backend)
  - `inv_scale` and `out_dtype` parameters from `multi_tensor_adam_fl`
  - `multi_tensor_adam_param_remainder_fl`: rewrite FP32 master weight reconstruction using bit manipulation (int16 high/low bits), matching the CUDA implementation exactly (see the sketch after this list)
- `multi_tensor.py` (FlagOS backend)
  - `multi_tensor_l2_norm_fl`: add proper type hints, a noop_flag check, inf/nan detection, and replace the raw `**`/`+` operators with `flag_gems.mul`/`flag_gems.add`
  - `multi_tensor_scale_fl`: add type hints, a noop_flag check, inf/nan detection, and replace `src * scale` with `flag_gems.mul(src, scale)` (see the sketch after this list)
- `optimizer.py` (reference backend)
  - Update `multi_tensor_l2norm_torch` and `multi_tensor_adam_torch` to match the new signatures and CUDA behavior (L2 vs AdamW mode split; see the sketch after this list)
  - `multi_tensor_adam_param_remainder_torch` with bit manipulation matching CUDA
  - Rename `eps` → `epsilon` for consistency
- `optimizers/__init__.py`
  - `multi_tensor_scale` and `multi_tensor_l2norm`

Misc
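The sketches below illustrate the techniques referenced in the Changes list. They are minimal, hedged sketches in plain PyTorch: function names, signatures, and helpers are illustrative and are not the actual code in this PR.

First, the bf16-parameter-plus-int16-remainder representation behind `multi_tensor_adam_param_remainder_fl` / `multi_tensor_adam_param_remainder_torch`: the FP32 master weight is kept as a bf16 parameter (high 16 bits) plus a signed int16 remainder (low 16 bits), and is reconstructed by recombining the two halves. The helper names and the halfway-rounding handling here are simplifications, not a claim about the exact CUDA kernel.

```python
import torch

def split_fp32_master(master: torch.Tensor):
    """Split fp32 master weights into a bf16 parameter (high 16 bits)
    and a signed int16 remainder (low 16 bits). Hypothetical helper."""
    bits = master.view(torch.int32).to(torch.int64)
    # Carry the low half's MSB into the high half so the remainder stays in int16 range.
    high = torch.div(bits + 0x8000, 1 << 16, rounding_mode="floor")
    low = (bits - (high << 16)).to(torch.int16)              # remainder in [-32768, 32767]
    param_bf16 = high.to(torch.int16).view(torch.bfloat16)   # NaN payload edge cases not handled
    return param_bf16, low

def reconstruct_fp32_master(param_bf16: torch.Tensor, remainder: torch.Tensor) -> torch.Tensor:
    """Rebuild fp32 master weights: high 16 bits come from the bf16 param,
    low 16 bits from the sign-extended remainder (which undoes the rounding carry)."""
    high = param_bf16.view(torch.int16).to(torch.int32)
    bits = (high << 16) + remainder.to(torch.int32)
    return bits.view(torch.float32)

# Round trip: master == reconstruct_fp32_master(*split_fp32_master(master))
```

Second, the noop_flag / non-finite-detection pattern added to `multi_tensor_scale_fl` (and, analogously, `multi_tensor_l2_norm_fl`). This sketch uses plain torch ops where the FlagOS backend calls `flag_gems.mul`, and the signature is illustrative only; it also skips the chunking that the fused multi_tensor kernels perform.

```python
from typing import List
import torch

def multi_tensor_scale_sketch(
    noop_flag: torch.Tensor,
    srcs: List[torch.Tensor],
    dsts: List[torch.Tensor],
    scale: float,
) -> None:
    """Write src * scale into dst; skip work if noop_flag is already set,
    and raise the flag when a non-finite value shows up."""
    if noop_flag.item() != 0:
        return  # an earlier kernel already flagged a problem, so this call is a no-op
    for src, dst in zip(srcs, dsts):
        scaled = src * scale  # the FlagOS path uses flag_gems.mul(src, scale) instead
        if not torch.isfinite(scaled).all():
            noop_flag.fill_(1)  # report inf/nan so the caller can skip the optimizer step
        dst.copy_(scaled)
```

Third, the L2 vs AdamW mode split that `multi_tensor_adam_torch` now mirrors. In L2 mode, weight decay is folded into the gradient before the moment updates; in AdamW mode, the decoupled decay is added to the final update. Parameter names and the exact update order are illustrative.

```python
from typing import List
import torch

def adam_step_sketch(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    exp_avg_sqs: List[torch.Tensor],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: int,
    weight_decay: float,
    adam_w_mode: bool,
) -> None:
    """Per-tensor Adam update illustrating the L2 vs AdamW mode split."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
        if not adam_w_mode and weight_decay != 0:
            # L2 mode: weight decay enters the gradient before the moments.
            g = g + weight_decay * p
        m.mul_(beta1).add_(g, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        denom = (v / bias_correction2).sqrt_().add_(epsilon)
        update = (m / bias_correction1) / denom
        if adam_w_mode and weight_decay != 0:
            # AdamW mode: decoupled decay is applied directly to the parameter update.
            update = update + weight_decay * p
        p.add_(update, alpha=-lr)
```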