Hi, and thanks for releasing Boltz.
I’ve been trying to set up Boltz 2.0.3 with GPU support on a cluster node and ran into a hard version conflict in the boltz → trifast → torch dependency chain against the currently available CUDA wheels.
Environment
- OS: Linux (cluster node, containerized)
- Python: 3.12
- GPU: NVIDIA L4
- nvidia-smi reports: Driver Version 580.105.08, CUDA Version 13.0
- Package manager: micromamba for envs, pip for Python packages
- Boltz: 2.0.3 (from PyPI)
What I did
Created a fresh env with Python 3.12.
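Roughly like this (exact channels and flags may have differed):

```bash
micromamba create -n boltz -c conda-forge python=3.12
micromamba activate boltz
```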
Installed CUDA-enabled PyTorch and torchvision from the official PyTorch cu121 index:

```bash
pip install "torch==2.5.1+cu121" "torchvision==0.20.1+cu121" \
  --index-url https://download.pytorch.org/whl/cu121
```
Installed boltz==2.0.3 and its dependencies, resolving version conflicts with numpy, numba, scipy, etc., so that the environment ended up with:

```
torch==2.5.1+cu121
torchvision==0.20.1+cu121
boltz==2.0.3
numba==0.61.0
numpy==1.26.4
```

plus other pinned deps (hydra-core, pytorch-lightning, rdkit, scikit-learn, wandb, trifast==0.1.13, …).
Prepared a YAML input following the documented Boltz schema:

```yaml
version: 1
sequences:
  - protein: {...}
  - ligand: {...}
```

and ran:

```bash
boltz predict XXX.yaml \
  --out_dir ./outputs/XXX \
  --use_msa_server
```
After fixing several missing Python deps (jaxtyping, wandb deps, etc.), the model reached the forward pass and then crashed:

```
  File ".../boltz/model/layers/triangular_attention/primitives.py", line 685, in _trifast_attn
    from trifast import triangle_attention
  File ".../trifast/torch.py", line 6, in <module>
    from torch.library import wrap_triton, triton_op
ImportError: cannot import name 'wrap_triton' from 'torch.library'
```
Analysis:
- boltz==2.0.3 pins trifast==0.1.13.
- trifast 0.1.13 declares torch>=2.6.0 and uses new Torch APIs (torch.library.wrap_triton, torch.library.triton_op) that only exist in torch ≥ 2.6.
- On the CUDA side, the official cu121 wheels at https://download.pytorch.org/whl/cu121 currently stop at torch==2.5.1+cu121, i.e. there is no torch>=2.6.0+cu121 wheel published yet.
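This is easy to confirm against the index (pip index is still marked experimental, so the output format may vary):

```bash
pip index versions torch --index-url https://download.pytorch.org/whl/cu121
```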
This leads to a hard conflict:
- To satisfy trifast==0.1.13, torch>=2.6.0 is required.
- To have a CUDA 12.1 binary, torch is limited to <=2.5.1+cu121.
- Installing torch==2.5.1+cu121 satisfies CUDA but breaks trifast at runtime with the wrap_triton ImportError.
- Forcing an older trifast version is blocked by Boltz’s strict trifast==0.1.13 pin.
On CPU (no CUDA), letting pip install torch from PyPI (CPU wheel) picks a ≥2.6.0 build and everything works: the wrap_triton symbol exists and boltz predict runs successfully, albeit slowly. So the problem appears specific to GPU / CUDA wheel availability.
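A quick way to check which side of the conflict a given environment is on (nothing Boltz-specific, just the symbol trifast tries to import):

```python
import torch

print(torch.__version__)  # e.g. 2.5.1+cu121 vs. 2.6.0+cpu
# wrap_triton / triton_op were only added to torch.library in torch 2.6
print(hasattr(torch.library, "wrap_triton"))  # False on 2.5.1, True on >= 2.6
```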
What I’m requesting
Would it be possible to:
1. Document a known-good, GPU-compatible version stack for Boltz 2.x (exact torch, CUDA, and trifast versions) and how to obtain it, or
2. Relax/update the dependencies so that either:
   - the trifast pin and its torch requirement are compatible with currently available CUDA wheels, or
   - an option is exposed to disable trifast at runtime (e.g., a --disable_trifast CLI flag or a config knob) and fall back to a pure-PyTorch attention implementation when wrap_triton isn’t available (see the sketch after this list), or
3. Provide guidance on the recommended way to run Boltz-2 on GPU today, given that:
   - cu121 wheels stop at torch 2.5.1,
   - Boltz pins trifast==0.1.13, and
   - trifast 0.1.13 requires torch>=2.6.0.
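For the fallback option, something like the following import guard is what I have in mind; all names here are illustrative, not Boltz’s actual internals:

```python
import torch

# Hypothetical guard: prefer trifast's fused Triton kernel when the
# torch.library APIs it needs exist, otherwise use an unfused PyTorch path.
try:
    from trifast import triangle_attention  # needs torch >= 2.6 (wrap_triton)
    HAS_TRIFAST = True
except ImportError:
    HAS_TRIFAST = False

def _reference_triangle_attention(q, k, v, bias, mask):
    # Rough stand-in for whatever unfused implementation Boltz already has:
    # plain scaled-dot-product attention with an additive bias and a bool mask.
    logits = torch.einsum("...qd,...kd->...qk", q, k) * q.shape[-1] ** -0.5
    logits = (logits + bias).masked_fill(~mask, float("-inf"))
    return torch.einsum("...qk,...kd->...qd", logits.softmax(dim=-1), v)

def triangle_attn(q, k, v, bias, mask):
    if HAS_TRIFAST:
        return triangle_attention(q, k, v, bias, mask)
    return _reference_triangle_attention(q, k, v, bias, mask)
```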
Right now, the situation seems to be:
- CPU-only: works (torch CPU ≥2.6.0 from PyPI, trifast happy).
- GPU: blocked, because no torch>=2.6.0+cu121 wheels exist yet and trifast 0.1.13 can’t run on ≤2.5.1 (missing wrap_triton).
If there is an internal/tested combination that works on GPU (e.g., specific torch/cu12.x wheels or a different trifast version), I’d really appreciate having that documented, or having Boltz’s dependency pins adjusted accordingly.
Thanks a lot for any clarification and for the work on Boltz.