sym-bot commented Jan 31, 2026

## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be installed but fail to import at runtime due to ABI mismatches. For example, when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8, the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check whether packages exist, but this only verifies that the package is installed, not that it can actually be imported. When the import fails, diffusers crashes instead of falling back to native PyTorch attention.
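
To illustrate the gap, a minimal sketch (the package name is only an example; this is not the actual diffusers check):

```python
import importlib.util

# find_spec() only looks for the package on disk; it does not execute the
# package's __init__, so a broken native extension goes undetected here.
if importlib.util.find_spec("flash_attn") is not None:
    # This import can still fail at runtime, e.g. with an "undefined symbol"
    # OSError when the wheel was built against a different PyTorch ABI.
    import flash_attn  # noqa: F401
```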

## Solution

Wrap all external attention backend imports in try-except blocks that catch `ImportError` and `OSError`. On failure:

1. Log a warning message explaining the issue.
2. Set the corresponding `_CAN_USE_*` flag to `False`.
3. Set the imported functions to `None`.

This allows diffusers to degrade gracefully to PyTorch's native SDPA (`scaled_dot_product_attention`) instead of crashing; a minimal sketch of the pattern follows.
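
A sketch of the guarded-import pattern, using flash_attn as an example (the flag and logger names here are illustrative; the actual changes live in `diffusers.models.attention_dispatch`):

```python
import importlib.util
import logging

logger = logging.getLogger(__name__)

_CAN_USE_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None
flash_attn_func = None

if _CAN_USE_FLASH_ATTN:
    try:
        from flash_attn import flash_attn_func  # noqa: F401
    except (ImportError, OSError) as exc:
        # Installed but unusable (e.g. ABI mismatch): warn and disable the
        # backend instead of letting the import error propagate.
        logger.warning(
            "flash_attn is installed but failed to import: %s. "
            "Falling back to native PyTorch attention.",
            exc,
        )
        _CAN_USE_FLASH_ATTN = False
        flash_attn_func = None
```

Backends whose `_CAN_USE_*` flag ends up `False` are simply not used, so attention falls back to `torch.nn.functional.scaled_dot_product_attention`.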

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4).

- Before: crashes on `from diffusers import ...` with an undefined symbol error
- After: logs a warning and uses native attention successfully

## Example warning output

```
WARNING:diffusers.models.attention_dispatch:flash_attn is installed but failed to import: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab. Falling back to native PyTorch attention.
```
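
A rough way to reproduce the before/after check from the Testing section (a sketch, assuming a flash_attn wheel with the ABI mismatch is installed; `DiffusionPipeline` is just a representative import):

```python
import logging

# Surface diffusers' warning-level log messages on the console.
logging.basicConfig(level=logging.WARNING)

# Previously this raised OSError (undefined symbol); now it should only warn.
from diffusers import DiffusionPipeline  # noqa: F401

import torch
import torch.nn.functional as F

# Native SDPA remains available as the fallback attention path.
q = k = v = torch.randn(1, 8, 128, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```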
