5 changes: 3 additions & 2 deletions Dockerfile
@@ -30,14 +30,15 @@
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics

FROM nvcr.io/nvidia/pytorch:24.12-py3
FROM nvcr.io/nvidia/pytorch:25.12-py3

# we need this for tests
RUN pip install parameterized

# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
ENV TORCH_HARMONICS_ENABLE_OPENMP=1
ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0+PTX"
COPY . /workspace/torch_harmonics
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
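The updated `TORCH_CUDA_ARCH_LIST` drops compute capabilities below 8.0, so the image now targets Ampere and newer GPUs only (with PTX for 10.0). As a quick sanity check inside the container, one could verify the visible GPU against that cutoff; the snippet below is an illustrative sketch, not part of the PR, and only relies on the standard `torch.cuda.get_device_capability` API:

```python
import torch

# Hypothetical runtime check: confirm the visible GPU is covered by the
# architectures compiled via TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0+PTX".
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (8, 0):
        raise RuntimeError(
            f"GPU reports compute capability {major}.{minor}; this image only "
            "builds the custom CUDA extension for 8.0 and newer."
        )
    print(f"compute capability {major}.{minor} is covered by the arch list")
```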

8 changes: 4 additions & 4 deletions Dockerfile.examples
@@ -30,7 +30,7 @@
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics

FROM nvcr.io/nvidia/pytorch:24.12-py3
FROM nvcr.io/nvidia/pytorch:25.12-py3

# we need this for tests
RUN pip install parameterized
@@ -48,14 +48,14 @@ RUN pip install h5py
RUN cd /opt && git clone https://github.com/SHI-Labs/NATTEN natten && \
    cd natten && \
    make WITH_CUDA=1 \
    CUDA_ARCH="7.0;7.2;7.5;8.0;8.6;8.7;9.0" \
    WORKERS=4
    CUDA_ARCH="8.0;8.6;8.7;9.0;10.0" \
    WORKERS=4

# install torch harmonics
COPY . /workspace/torch_harmonics

# The custom CUDA extension does not support architectures < 7.0
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_HARMONICS_ENABLE_OPENMP=1
ENV TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0+PTX"
RUN cd /workspace/torch_harmonics && pip install --no-build-isolation .
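The same architecture bump applies to the NATTEN build (`CUDA_ARCH`) and to the torch_harmonics extension (`TORCH_CUDA_ARCH_LIST`). A small, illustrative snippet (not part of the PR) that prints what the PyTorch wheel in the base image itself was compiled for, which is a useful cross-check when changing these lists:

```python
import torch

# Report the toolchain baked into the base image. The custom extension and
# NATTEN are compiled separately, following the arch lists set above.
print("PyTorch version:", torch.__version__)
print("CUDA runtime version:", torch.version.cuda)
print("PyTorch compiled arch list:", torch.cuda.get_arch_list())
```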
146 changes: 95 additions & 51 deletions tests/test_convolution.py

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions tests/testutils.py
@@ -30,6 +30,8 @@
#

import os
from packaging import version

import torch
import torch.distributed as dist
import torch_harmonics.distributed as thd
@@ -41,6 +43,20 @@ def set_seed(seed=333):
    return


def disable_tf32():
    # the API for this changed recently in PyTorch
    if torch.cuda.is_available():
        if version.parse(torch.__version__) >= version.parse("2.9.0"):
            torch.backends.cuda.matmul.fp32_precision = "ieee"
            torch.backends.cudnn.fp32_precision = "ieee"
            torch.backends.cudnn.conv.fp32_precision = "ieee"
            torch.backends.cudnn.rnn.fp32_precision = "ieee"
        else:
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
    return


def setup_distributed_context(ctx):
    ctx.world_rank = int(os.getenv("WORLD_RANK", 0))
    ctx.grid_size_h = int(os.getenv("GRID_H", 1))
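The new `disable_tf32` helper pins matmuls and cuDNN to true FP32 ("ieee") so that tight tolerances in the tests remain meaningful on Ampere and newer GPUs, where TF32 is enabled by default. A sketch of the intended usage, with a hypothetical test case; only the `disable_tf32` call itself comes from this PR:

```python
import unittest

import torch

from testutils import disable_tf32  # helper added in this PR


class MatmulPrecisionTest(unittest.TestCase):
    """Hypothetical example of calling disable_tf32() before tolerance checks."""

    def setUp(self):
        # without this, TF32 matmuls on Ampere+ GPUs can exceed tight fp32 tolerances
        disable_tf32()

    def test_matmul_matches_fp64_reference(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        a = torch.randn(256, 256, device=device)
        b = torch.randn(256, 256, device=device)
        ref = (a.double() @ b.double()).float()
        self.assertTrue(torch.allclose(a @ b, ref, rtol=1e-4, atol=1e-4))


if __name__ == "__main__":
    unittest.main()
```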
39 changes: 25 additions & 14 deletions torch_harmonics/disco/_disco_utils.py
@@ -30,7 +30,6 @@
#

from typing import Optional
import math

import torch
from disco_helpers import optimized_kernels_is_available
@@ -183,17 +182,22 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)
    xtype = x.dtype
    x = x.to(torch.float32).contiguous()

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)
    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        for pout in range(nlon_out):
            # sparse contraction with psi
            y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
            # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
            x = torch.roll(x, -pscale, dims=2)

    y = y.to(xtype)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)
    y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous()

    return y
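The change above upcasts the input to FP32 and wraps the `torch.bmm` loop in an autocast-disabled region, so the sparse contraction is not evaluated in reduced precision when the caller runs under AMP. A self-contained sketch of that pattern with made-up dense shapes (the real code uses a sparse `psi` and keeps the rolling logic shown above):

```python
import torch

def bmm_in_fp32(psi: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Illustrative pattern: run a batched matmul in true FP32 even inside a
    torch.amp.autocast region, then cast the result back to the input dtype."""
    xtype = x.dtype
    x32 = x.to(torch.float32)
    psi32 = psi.to(torch.float32)
    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        y = torch.bmm(psi32, x32)
    return y.to(xtype)

if torch.cuda.is_available():
    psi = torch.randn(4, 8, 16, device="cuda")
    x = torch.randn(4, 16, 32, device="cuda")
    # even under bf16 autocast, the contraction itself runs in FP32
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = bmm_in_fp32(psi, x)
    print(out.dtype)  # torch.float32, since x was created in fp32
```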

@@ -218,18 +222,25 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
    # we only need to apply the nlon stride here, since the nlat stride is taken care of by the kernel
    x_ext[:, :, ::pscale, :] = x[...]

    xtype = x_ext.dtype
    x_ext = x_ext.to(torch.float32).contiguous()

    # create output tensor
    y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)

    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        for pout in range(nlon_out):
            # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
            # TODO: double-check why this has to happen first
            x_ext = torch.roll(x_ext, -1, dims=2)
            # sparse contraction with the modified psi
            y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))

    # sum over the kernel dimension and reshape to the correct output size
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)

    # convert datatype back to input type
    y = y.to(xtype).contiguous()

    return y
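Both contractions realize the per-output-longitude shift with `torch.roll`: the forward version contracts first and then rolls by `-pscale`, while the transpose version rolls by `-1` before each contraction (hence the TODO above). A toy illustration of that indexing invariant, with made-up sizes:

```python
import torch

nlon, pscale = 8, 2
x = torch.arange(nlon)

# forward pattern: contract, then roll by -pscale; after `pout` rolls,
# position 0 holds the sample that originally sat at longitude pout * pscale
rolled = x.clone()
for pout in range(nlon // pscale):
    assert rolled[0] == x[(pout * pscale) % nlon]
    rolled = torch.roll(rolled, -pscale, dims=0)

# transpose pattern: roll by -1 first, then contract; output longitude pout
# therefore sees the input shifted by pout + 1 positions
rolled = x.clone()
for pout in range(nlon):
    rolled = torch.roll(rolled, -1, dims=0)
    assert rolled[0] == x[(pout + 1) % nlon]
```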
