120 changes: 57 additions & 63 deletions transformer_engine/plugin/core/backends/flagos/impl/fused_adam.py
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.

from typing import Optional, List
from typing import List
import torch
import flag_gems

@@ -19,8 +19,6 @@ def multi_tensor_adam_fl(
mode: int,
bias_correction: int,
weight_decay: float,
inv_scale: Optional[float] = 1.0,
out_dtype: Optional[torch.dtype] = None,
) -> None:

num_lists = len(tensor_lists)
@@ -50,9 +48,6 @@ def multi_tensor_adam_fl(
if not g.is_contiguous():
g = g.contiguous()

if inv_scale is not None and inv_scale != 1.0:
g = flag_gems.mul(g, inv_scale)

m = flag_gems.add_(flag_gems.mul_(m, beta1), g, alpha=1-beta1)
v = flag_gems.add_(flag_gems.mul_(v, beta2), flag_gems.mul_(flag_gems.mul_(g, g), 1 - beta2))

@@ -73,9 +68,6 @@ def multi_tensor_adam_fl(

if p_master is not None:
flag_gems.copy_(p_master, p)
out_dtype = p_master.dtype if out_dtype is None else out_dtype
p.data = p.data.to(out_dtype)


def multi_tensor_adam_param_remainder_fl(
chunk_size: int,
@@ -89,27 +81,9 @@ def multi_tensor_adam_param_remainder_fl(
mode: int,
bias_correction: int,
weight_decay: float,
inv_scale: Optional[float] = 1.0,
) -> None:
"""
Adam optimizer with parameter remainders for BF16 precision (FlagOS implementation).

This variant stores BF16 parameters + int16 remainders to reconstruct FP32 master weights.
Used when you have BF16 params and need FP32 master params without storing full FP32 copies.

Args:
chunk_size: Chunk size for processing (unused in this implementation)
noop_flag: If non-zero, skip computation
tensor_lists: [grads, params (bf16), exp_avgs (fp32), exp_avg_sqs (fp32), param_remainders (int16)]
lr: Learning rate
beta1: First moment decay rate
beta2: Second moment decay rate
eps: Epsilon for numerical stability
step: Current optimization step
mode: 0 = L2 regularization, 1 = AdamW (decoupled weight decay)
bias_correction: Whether to apply bias correction (1 = yes, 0 = no)
weight_decay: Weight decay coefficient
inv_scale: Inverse gradient scale for mixed precision training
"""
if noop_flag.item() != 0:
return
@@ -133,58 +107,78 @@ def multi_tensor_adam_param_remainder_fl(

for i in range(num_tensors):
g = tensor_lists[0][i]
p = tensor_lists[1][i] # BF16 parameter
p = tensor_lists[1][i] # int16 parameter (high 16 bits of FP32)
m = tensor_lists[2][i] # FP32 first moment
v = tensor_lists[3][i] # FP32 second moment
p_remainder = tensor_lists[4][i] # int16 remainder
p_remainder = tensor_lists[4][i] # int16 remainder (low 16 bits of FP32)

if not g.is_contiguous():
g = g.contiguous()

# Apply gradient unscaling if needed
if inv_scale is not None and inv_scale != 1.0:
g = flag_gems.mul(g, inv_scale)
# Convert gradient to float
g_float = g.float()

# Reconstruct FP32 master weight from int16 param + int16 remainder using bit manipulation
# This matches the CUDA implementation exactly:
# 1. If p_remainder < 0, decrement p (undo rounding)
# 2. Combine high 16 bits (p) and low 16 bits (p_remainder) into FP32
# Note: Use PyTorch native ops for bit manipulation (int16/int32 operations)

# Reconstruct FP32 master weight from BF16 param + int16 remainder
# The remainder represents the lower 16 bits lost in BF16 conversion
param_fp32 = p.float()
param_master = flag_gems.add(param_fp32, flag_gems.mul(p_remainder.float(), 2.0 ** -16))
local_p = p.view(torch.int16).clone()
local_p_rem = p_remainder.clone()

# Compute gradient with weight decay (if L2 mode)
grad_with_decay = g.float()
if not is_adamw: # L2 regularization mode
grad_with_decay = flag_gems.add(grad_with_decay, flag_gems.mul(param_master, weight_decay))
# Undo rounding: if remainder < 0, decrement p
local_p = torch.where(local_p_rem < 0, local_p - 1, local_p)

# Update moments
m = flag_gems.add_(flag_gems.mul_(m, beta1), grad_with_decay, alpha=1 - beta1)
v = flag_gems.add_(flag_gems.mul_(v, beta2), flag_gems.mul_(flag_gems.mul_(grad_with_decay, grad_with_decay), 1 - beta2))
# Combine into FP32 using bit shift operations
# local_p is high 16 bits, local_p_rem is low 16 bits
high_bits = local_p.to(torch.int32) << 16
low_bits = local_p_rem.to(torch.int32) & 0xFFFF # Mask off sign extension
param_int32 = high_bits | low_bits
param_master = param_int32.view(torch.float32)

# L2 mode: add weight decay to gradient before updating moments
if not is_adamw and weight_decay != 0:
g_float = flag_gems.add(g_float, param_master, alpha=weight_decay)

# Update first moment: m = beta1 * m + (1 - beta1) * g
flag_gems.add_(flag_gems.mul_(m, beta1), g_float, alpha=1 - beta1)

# Update second moment: v = beta2 * v + (1 - beta2) * g^2
flag_gems.add_(flag_gems.mul_(v, beta2), flag_gems.mul(g_float, g_float), alpha=1 - beta2)

# Apply bias correction
m_corr = m.clone()
v_corr = v.clone()
if bias_correction == 1:
m_corr = flag_gems.true_divide(m_corr, bias_correction1)
v_corr = flag_gems.true_divide(v_corr, bias_correction2)
m_corr = flag_gems.true_divide(m, bias_correction1)
v_corr = flag_gems.true_divide(v, bias_correction2)

# Compute denominator: sqrt(v_corr) + eps
denom = flag_gems.add(flag_gems.sqrt(v_corr), eps)

# Compute update
update = flag_gems.true_divide(m_corr, flag_gems.add(flag_gems.sqrt(v_corr), eps))
update = flag_gems.true_divide(m_corr, denom)

# Apply weight decay (if AdamW mode)
if is_adamw:
param_master = flag_gems.mul_(param_master, 1 - lr * weight_decay)
# AdamW mode: add decoupled weight decay to update
if is_adamw and weight_decay != 0:
update = flag_gems.add(update, param_master, alpha=weight_decay)

# Update master weight: p = p - lr * update
param_master = flag_gems.sub(param_master, flag_gems.mul(update, lr))

# Update master weight
param_master = flag_gems.add_(param_master, update, alpha=-lr)
# Split FP32 back into int16 param + int16 remainder using bit manipulation
# This matches the CUDA implementation exactly:
# 1. Extract high 16 bits as p
# 2. Extract low 16 bits as p_remainder
# 3. If p_remainder < 0, increment p (round up)
# Note: Use PyTorch native ops for bit manipulation (int32 operations)

# Split back into BF16 param + int16 remainder
# Convert to BF16 (this is the rounded version)
param_bf16 = param_master.to(dtype=p.dtype)
param_int32 = param_master.view(torch.int32)
# Extract low 16 bits (remainder) and high 16 bits (param)
new_p_rem = (param_int32 & 0xFFFF).to(torch.int16)
new_p = ((param_int32 >> 16) & 0xFFFF).to(torch.int16)

# Compute remainder: difference between FP32 master and BF16 representation
# Scale and quantize to int16 range
remainder_fp32 = flag_gems.mul(flag_gems.sub(param_master, param_bf16.float()), 2.0 ** 16)
remainder_int16 = flag_gems.clamp(torch.round(remainder_fp32), -32768, 32767).to(dtype=torch.int16)
# Round up: if remainder < 0, increment p
new_p = torch.where(new_p_rem < 0, new_p + 1, new_p)

# Write back
flag_gems.copy_(p, param_bf16)
flag_gems.copy_(p_remainder, remainder_int16)
flag_gems.copy_(p, new_p.view(torch.bfloat16))
flag_gems.copy_(p_remainder, new_p_rem)
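
For clarity, here is a self-contained sketch of the FP32 <-> (bf16 high half, int16 remainder) bit round trip that the rewritten param-remainder path above relies on, written in plain PyTorch. It is illustrative only and not part of the PR; the helper names split_fp32 and merge_fp32 are invented for the example, and NaN/Inf inputs are not handled.

# Illustrative sketch (not part of this PR): the bit round trip used by
# multi_tensor_adam_param_remainder_fl above. Helper names are invented for
# the example; NaN/Inf edge cases are ignored.
import torch

def split_fp32(master: torch.Tensor):
    """Split an FP32 tensor into its high 16 bits (viewed as bf16) and low 16 bits (int16)."""
    bits = master.view(torch.int32)
    remainder = (bits & 0xFFFF).to(torch.int16)        # low 16 bits
    high = ((bits >> 16) & 0xFFFF).to(torch.int16)     # high 16 bits
    # If the remainder's sign bit is set, bump the high half so the stored bf16
    # value reflects round-to-nearest (mirrors the rounding in the kernel above).
    high = torch.where(remainder < 0, high + 1, high)
    return high.view(torch.bfloat16), remainder

def merge_fp32(p_bf16: torch.Tensor, remainder: torch.Tensor) -> torch.Tensor:
    """Rebuild the FP32 master weight from the bf16 parameter and int16 remainder."""
    high = p_bf16.view(torch.int16).to(torch.int32)
    high = torch.where(remainder < 0, high - 1, high)  # undo the rounding bump
    bits = (high << 16) | (remainder.to(torch.int32) & 0xFFFF)
    return bits.view(torch.float32)

master = torch.randn(8)
p_bf16, rem = split_fp32(master)
restored = merge_fp32(p_bf16, rem)
assert torch.equal(restored.view(torch.int32), master.view(torch.int32))  # bit-exact

The diff then moves on to the PR's second changed file, which holds the multi-tensor L2-norm and scale helpers.
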
@@ -2,25 +2,66 @@
#
# See LICENSE for license information.

from typing import List, Tuple
import torch
from torch.distributed._tensor import DTensor
import flag_gems


def multi_tensor_l2_norm_fl(chunk_size, noop_flag, tensor_lists, per_tensor, *args):
def multi_tensor_l2_norm_fl(
_chunk_size: int,
noop_flag: torch.Tensor,
tensor_lists: List[List[torch.Tensor]],
per_tensor: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute L2 norm of tensors using flag_gems.

Returns:
Tuple of (total_norm, per_tensor_norms_or_dummy)
- total_norm: The combined L2 norm of all tensors
- per_tensor_norms_or_dummy: Per-tensor norms stacked if per_tensor=True, else dummy tensor
"""
device = tensor_lists[0][0].device if tensor_lists and tensor_lists[0] else 'cpu'

if noop_flag.item() != 0:
return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)

tensors = tensor_lists[0]

# Compute per-tensor norms
per_tensor_norms = []
total_norm_sq = torch.tensor(0.0, device=device)

for tensor in tensors:
t_float = tensor.float()
norm_sq = flag_gems.sum(flag_gems.mul(t_float, t_float))
# Check for inf/nan (matches CUDA behavior)
if not torch.isfinite(norm_sq):
noop_flag.fill_(1)
total_norm_sq = flag_gems.add(total_norm_sq, norm_sq)
if per_tensor:
per_tensor_norms.append(flag_gems.sqrt(norm_sq))

total_norm = flag_gems.sqrt(total_norm_sq)

if per_tensor:
norms = [torch.norm(t.float(), p=2) for t in tensors]
return norms, None
per_tensor_result = torch.stack(per_tensor_norms)
else:
total_norm_sq = sum(flag_gems.sum(flag_gems.pow_func(t.float(), 2)) for t in tensors)
total_norm = flag_gems.sqrt(total_norm_sq)
return total_norm, None
per_tensor_result = torch.tensor(0.0, device=device)

return total_norm, per_tensor_result

def multi_tensor_scale_fl(chunk_size, noop_flag, tensor_lists, scale):
def multi_tensor_scale_fl(
_chunk_size: int,
noop_flag: torch.Tensor,
tensor_lists: List[List[torch.Tensor]],
scale: float,
) -> None:
if noop_flag.item() != 0:
return

for src, dst in zip(tensor_lists[0], tensor_lists[1]):
flag_gems.copy_(dst, src * scale)
# Check for inf/nan (matches CUDA behavior for AMP gradient scaling)
if not torch.isfinite(src).all():
noop_flag.fill_(1)
flag_gems.copy_(dst, flag_gems.mul(src, scale))
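
As a usage note, these two helpers compose in the usual multi-tensor way for global-norm gradient clipping. The sketch below is an assumption about how a caller might wire them together; the clip_grads_fl name, the max_norm parameter, and the in-place scaling of grads are illustrative and not part of the PR.

# Hypothetical caller-side sketch (not part of this PR): global-norm clipping
# built from multi_tensor_l2_norm_fl and multi_tensor_scale_fl defined above.
import torch

def clip_grads_fl(grads, max_norm, eps=1e-6):
    noop_flag = torch.zeros(1, dtype=torch.int32, device=grads[0].device)
    # Chunk size is unused by these FlagOS implementations, so 0 is passed.
    total_norm, _ = multi_tensor_l2_norm_fl(0, noop_flag, [grads], per_tensor=False)
    clip_coef = min(1.0, max_norm / (float(total_norm) + eps))
    if clip_coef < 1.0:
        # Scale gradients in place (src and dst lists are the same tensors).
        # If the norm pass saw inf/nan it set noop_flag, so the scale becomes a no-op.
        multi_tensor_scale_fl(0, noop_flag, [grads, grads], clip_coef)
    return total_norm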