GeLU kernel fails on non-contiguous tensors (ModernBERT) #106

@AmitMY

Description

Might be the same as #76

Summary

The GeLU kernel from kernels-community/activation produces incorrect outputs for non-contiguous tensors. This causes completely wrong predictions when used with ModernBERT, which uses .chunk() to create non-contiguous activation inputs.

Root Cause

ModernBERT's MLP uses a gated architecture:

def forward(self, hidden_states):
    input, gate = self.Wi(hidden_states).chunk(2, dim=-1)  # Creates non-contiguous tensors!
    return self.Wo(self.drop(self.act(input) * gate))

The .chunk() operation returns strided views of the original tensor, so the chunks are not contiguous in memory. The kernel's ops.gelu function does not handle these strided inputs correctly.
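
The non-contiguity is easy to see in isolation. The snippet below is a minimal, standalone sketch (not ModernBERT code); the width 2304 simply mirrors the stride printed in the proof below and is assumed to correspond to ModernBERT-base's Wi output (2 × 1152):

import torch

x = torch.randn(9, 2304)                 # stand-in for the Wi projection output
act_input, gate = x.chunk(2, dim=-1)     # two (9, 1152) views of the same storage
print(act_input.is_contiguous())         # False
print(act_input.stride())                # (2304, 1) - row stride spans both chunks
print(act_input.contiguous().stride())   # (1152, 1) - after copying into fresh memory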

Proof

# Non-contiguous input: wi_output is the ModernBERT Wi projection output
act_input, gate = wi_output.chunk(2, dim=-1)
print(act_input.is_contiguous())           # False
print(act_input.stride())                  # (2304, 1) - not contiguous!

# gelu_kernel: the kernels-community/activation op; gelu_original: Transformers' GELUActivation
# Kernel fails on the non-contiguous view
out_kern = gelu_kernel(act_input)
out_orig = gelu_original(act_input)
print((out_orig - out_kern).abs().max())   # 17.5 ❌

# Kernel works on a contiguous copy
out_kern = gelu_kernel(act_input.contiguous())
out_orig = gelu_original(act_input.contiguous())
print((out_orig - out_kern).abs().max())   # 0.0 ✓
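
The same check can be packaged as a small self-contained harness that compares any GELU implementation against torch.nn.functional.gelu on a chunked input. gelu_kernel below is a placeholder argument, not the kernel's actual entry point; wire in the real op to reproduce the mismatch:

import torch
import torch.nn.functional as F

def check_gelu_on_chunked_input(gelu_kernel, rows=9, width=2304, device="cpu"):
    """Compare a GELU callable against F.gelu on a non-contiguous chunk.

    gelu_kernel is any callable mapping a tensor to GELU(tensor); substitute
    the real kernel op here (placeholder, not this issue's actual API).
    """
    x = torch.randn(rows, width, device=device)
    act_input, _gate = x.chunk(2, dim=-1)   # non-contiguous view, stride (width, 1)
    assert not act_input.is_contiguous()
    max_diff = (F.gelu(act_input) - gelu_kernel(act_input)).abs().max().item()
    print(f"max diff vs F.gelu on non-contiguous input: {max_diff}")
    return max_diff

# Sanity check with a stride-aware implementation (prints 0.0):
check_gelu_on_chunked_input(F.gelu)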

Environment

  • Model: answerdotai/ModernBERT-base
  • Kernel: GeLU from kernels-community/activation
  • PyTorch: 2.x with CUDA

Minimal Reproduction

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from kernels import kernelize, Mode

model_id = "answerdotai/ModernBERT-base"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "The capital of France is [MASK]."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Without kernelize - CORRECT
model = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
outputs = model(**inputs)
masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
print("Without kernelize:", tokenizer.decode(outputs.logits[0, masked_index].argmax()))
# Output: Paris ✓

# With kernelize - WRONG (due to non-contiguous tensor bug)
model2 = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
model2 = kernelize(model2, mode=Mode.INFERENCE, device=device)
outputs2 = model2(**inputs)
print("With kernelize:", tokenizer.decode(outputs2.logits[0, masked_index].argmax()))
# Output: required ✗

Impact

Metric            Value
Max logit diff    47.0
Mean logit diff   5.6
Prediction        Completely wrong

Affected Models

Any model that passes non-contiguous tensors to GELUActivation, including:

  • ModernBERT (uses .chunk() in gated MLP)
  • Potentially other gated architectures

Recommendation

  • Do not use kernelize on ModernBERT until the kernel is fixed
  • The GeLU kernel should either:
    1. Call .contiguous() on the input before processing (see the sketch after this list), or
    2. Handle strided tensors correctly in ops.gelu
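
A minimal sketch of option 1, assuming the fix is applied on the Python side before dispatch; kernel_gelu is a hypothetical stand-in for the actual ops.gelu entry point, whose real name and signature may differ:

import torch

def contiguous_gelu(kernel_gelu):
    """Wrap a contiguity-assuming GELU op so callers may pass strided views."""
    def wrapped(x: torch.Tensor) -> torch.Tensor:
        if not x.is_contiguous():
            x = x.contiguous()   # materialize the chunked view before the kernel sees it
        return kernel_gelu(x)
    return wrapped

This is the same workaround as calling self.act(input.contiguous()) inside ModernBERT's MLP; option 2 (stride-aware indexing inside the CUDA kernel itself) avoids the extra copy but requires changing the kernel.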
