# torch-harmonics-kernels

Prebuilt CUDA kernels for the torch-harmonics DISCO (DIscrete-COntinuous) spherical convolution and spherical neighborhood attention operators, packaged for the Hugging Face kernels ecosystem.
## Quickstart

```python
from kernels import get_kernel

mod = get_kernel("kashif/torch-harmonics-kernels", version=1)

# DISCO forward
y = mod.disco_forward(
    inp, roff_idx, ker_idx, row_idx, col_idx, vals,
    kernel_size, nlat_out, nlon_out,
)

# Spherical neighborhood attention forward
y = mod.attention_forward(
    kx, vx, qy, quad_weights, col_idx, row_off,
    nlon_in, nlat_out, nlon_out,
)
```
## Functions

### DISCO convolution

```python
disco_forward(
    inp,               # [B, C, Hi, Wi] float32/float64, contiguous CUDA
    roff_idx,          # [nrows+1] int64, CSR row offsets
    ker_idx,           # [nnz] int64, kernel-bin index per non-zero
    row_idx,           # [nnz] int64, output-latitude row per non-zero
    col_idx,           # [nnz] int64, flattened input position (h * Wi + w)
    vals,              # [nnz] float32/float64, sparse weights
    kernel_size: int,  # K (number of basis functions)
    nlat_out: int,     # Ho
    nlon_out: int,     # Wo
) -> torch.Tensor      # [B, C, K, Ho, Wo]

disco_backward(
    grad,              # [B, C, K, Ho, Wo] gradient w.r.t. forward output
    roff_idx, ker_idx, row_idx, col_idx, vals,
    kernel_size, nlat_out, nlon_out,
) -> torch.Tensor      # [B, C, Ho, Wo] gradient w.r.t. input
```
Constraint: for `disco_forward`, `Wi` must be an integer multiple of `Wo`; for `disco_backward`, `Wo` must be an integer multiple of `Wi`.
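Given the shape annotations above, the output shape and the divisibility constraint can be checked before dispatching to the CUDA kernel. A minimal pure-Python sketch (the helper name `disco_forward_out_shape` is hypothetical, not part of this package):

```python
# Hypothetical pre-flight check mirroring disco_forward's documented contract:
# input [B, C, Hi, Wi] with Wi a multiple of Wo, output [B, C, K, Ho, Wo].

def disco_forward_out_shape(inp_shape, kernel_size, nlat_out, nlon_out):
    """Return the [B, C, K, Ho, Wo] output shape, enforcing the Wi % Wo constraint."""
    B, C, Hi, Wi = inp_shape
    if Wi % nlon_out != 0:
        raise ValueError(f"Wi ({Wi}) must be an integer multiple of Wo ({nlon_out})")
    return (B, C, kernel_size, nlat_out, nlon_out)

print(disco_forward_out_shape((2, 8, 64, 128), kernel_size=3, nlat_out=32, nlon_out=64))
# -> (2, 8, 3, 32, 64)
```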
### Spherical neighborhood attention

```python
attention_forward(
    kx,              # [B, C, nlat_in, nlon_in] float32, key tensor
    vx,              # [B, C, nlat_in, nlon_in] float32, value tensor
    qy,              # [B, C, nlat_out, nlon_out] float32, query tensor
    quad_weights,    # [nlat_in] float32, quadrature weights
    col_idx,         # [nnz] int64, flattened input position
    row_off,         # [nlat_out+1] int64, CSR row offsets
    nlon_in: int,
    nlat_out: int,
    nlon_out: int,
) -> torch.Tensor    # [B, C, nlat_out, nlon_out]

attention_backward(
    kx, vx, qy,
    dy,              # [B, C, nlat_out, nlon_out] float32, output gradient
    quad_weights, col_idx, row_off,
    nlon_in, nlat_out, nlon_out,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]  # (dkx, dvx, dqy)
```
Constraint: `nlon_in` must be an integer multiple of `nlon_out` (the kernel uses `pscale = nlon_in // nlon_out` for downsampling).
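One plausible reading of this constraint: with `pscale = nlon_in // nlon_out`, each output longitude `wo` lines up with input longitude `wo * pscale`. A plain-Python sketch of that index arithmetic (illustrative only, not the kernel code):

```python
# Sketch of the longitude downsampling bookkeeping implied by the constraint.
# Parameter names follow the signatures above; the mapping shown is an
# assumption for illustration, not the kernel's exact access pattern.

nlon_in, nlon_out = 128, 32
assert nlon_in % nlon_out == 0, "nlon_in must be a multiple of nlon_out"
pscale = nlon_in // nlon_out

# Output longitude wo aligns with input longitude wo * pscale
# (neighborhood offsets then fan out from there).
input_columns = [wo * pscale for wo in range(nlon_out)]
print(pscale, input_columns[:4])  # 4 [0, 4, 8, 12]
```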
## Helpers

```python
preprocess_psi(kernel_size, nlat_out, ker_idx, row_idx, col_idx, vals) -> torch.Tensor
# Sorts the COO sparse matrix in place by kernel index and returns the CSR
# row-offset tensor expected by disco_forward / disco_backward.

cuda_kernels_is_available() -> bool       # always True for this build
optimized_kernels_is_available() -> bool
```
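As a rough illustration of the bookkeeping behind `preprocess_psi`, here is a pure-Python sketch of turning a COO `row_idx` array into the `[nrows + 1]` CSR row-offset array the kernels consume (the real helper works on CUDA tensors and also sorts by kernel index; `csr_row_offsets` is a hypothetical name):

```python
# Build CSR row offsets from COO row indices:
# count non-zeros per row, then prefix-sum into [nrows + 1] offsets.

def csr_row_offsets(row_idx, nrows):
    counts = [0] * nrows
    for r in row_idx:
        counts[r] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)
    return offsets

# Three rows holding 2, 0 and 3 non-zeros respectively:
print(csr_row_offsets([0, 0, 2, 2, 2], nrows=3))  # [0, 2, 2, 5]
```

Row `i`'s non-zeros then live in the half-open slice `offsets[i]:offsets[i + 1]` of the sorted COO arrays.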
## Index conventions

The sparse (`ker_idx`, `row_idx`, `col_idx`, `vals`) tensors describe the discretised basis functions of the spherical convolution / neighborhood. They are computed once by torch-harmonics from the input/output grids and the basis definition, then handed to these kernels at every forward/backward call. If you are building on top of torch-harmonics, you can pull these tensors directly from the `DiscreteContinuousConvS2` / `NeighborhoodAttentionS2` module buffers (`psi_roff_idx`, `psi_ker_idx`, `psi_row_idx`, `psi_col_idx`, `psi_vals`).
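The `col_idx` flattening convention noted in the signatures (`h * Wi + w`) can be shown in plain Python; `flatten` / `unflatten` are hypothetical helpers for illustration, not part of the package:

```python
# Row-major flattening of a grid position (h, w) on an [Hi, Wi] input grid,
# as stored in col_idx, and its inverse.

Wi = 8  # hypothetical input grid width (nlon_in)

def flatten(h, w, Wi):
    return h * Wi + w

def unflatten(col, Wi):
    return divmod(col, Wi)  # (h, w)

col = flatten(3, 5, Wi)
print(col, unflatten(col, Wi))  # 29 (3, 5)
```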
## Build variants

This release ships only the variant matching the typical Hugging Face runtime environment:

- `torch210-cxx11-cu128-x86_64-linux`

For other torch / CUDA combinations, build from source with kernel-builder.
## Build from source

```shell
git clone https://huggingface.co/kernels/kashif/torch-harmonics-kernels
cd torch-harmonics-kernels
kernel-builder build --variant torch210-cxx11-cu128-x86_64-linux --cores 8 -L
pytest tests/test_local_kernel.py -q
```
## What this package contains

- Vendored CUDA sources from the upstream torch-harmonics `optimized/kernels_cuda/` tree, exposed as the four operators above plus `preprocess_psi`.
- Local optimisations on top of upstream:
  - Cached `sortRows`: the CSR row-length sort runs once per unique `psi_row_off` tensor instead of every forward/backward call.
  - Cached `getPtxver`: eliminates the synchronising `cudaFuncGetAttributes` call from the per-call permutation path.
  - Single-`expf` online-softmax inner loop (vs the upstream two-`expf` form) in both forward and backward.
  - Same `pscale = nlon_in / nlon_out` downsampling support as upstream, with matching `TORCH_CHECK` assertions.
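For readers unfamiliar with the online-softmax recurrence behind the `expf`-count optimisation, here is a scalar Python sketch of the streaming computation (an illustrative recreation, not the kernel's CUDA code; the two `math.exp` calls per step below are what the single-`expf` variant merges):

```python
# Numerically stable streaming softmax(scores)-weighted sum of values,
# in one pass: keep a running max m, denominator d and accumulator acc,
# rescaling the old state whenever a new maximum appears.

import math

def online_softmax_weighted_sum(scores, values):
    m = float("-inf")  # running maximum of scores seen so far
    d = 0.0            # running softmax denominator, scaled by exp(-m)
    acc = 0.0          # running weighted sum, same scaling
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new) if m != float("-inf") else 0.0
        w = math.exp(s - m_new)
        d = d * scale + w
        acc = acc * scale + w * v
        m = m_new
    return acc / d
```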
## License
BSD-3-Clause, inherited from torch-harmonics.