Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of CUDA-ised transform #602

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
69f2f24
Allow generation of CUFFT plans
jackleland Oct 23, 2024
5db4d18
Working CUFFT implementation for batched and serial forward transforms
jackleland Oct 31, 2024
f51fb40
Remove unnecessary reference to CUDA extension
jackleland Nov 11, 2024
c43d99d
Add additional methods for reverse transform
jackleland Nov 11, 2024
7cfc88f
First working version of cuda-ised inverse legendre transform
jackleland Nov 25, 2024
9b3dafb
Tidy up legendre inverse kernel
jackleland Nov 25, 2024
9c522f3
Loosen type restriction on LowerTriangularArray in _legendre!
jackleland Nov 25, 2024
aa78bd2
Merge branch 'main' into jl/fourier_cuda
milankl Nov 25, 2024
f4e99b0
WIP: lm_range via ij2k
milankl Nov 25, 2024
2d3d6c0
use ArrayType_ in SpectralTransform generator
milankl Nov 25, 2024
49afcd2
add SpeedyWeather. to access LTMs
milankl Nov 25, 2024
d57e0a0
update changelog
milankl Nov 25, 2024
e08cf04
Fix minor typos
jackleland Nov 25, 2024
a5c8d5e
Fix get_lm_range to match the original loop indexing
jackleland Nov 26, 2024
a827e76
Remove unused reference to LinearAlgebra
jackleland Dec 10, 2024
5e1c6c9
Remove reference to gpu.jl
jackleland Dec 10, 2024
df6824c
Merge branch 'main' into jl/fourier_cuda
jackleland Dec 10, 2024
03a12d0
Merge branch 'main' into jl/fourier_cuda
jackleland Dec 16, 2024
c06cc34
Update legendre kernel to use kjm indices stored as a matrix in Spect…
jackleland Dec 16, 2024
8d9c049
Add unit tests for components of transform!
jackleland Dec 16, 2024
52e4f8f
Fix tolerance on legendre kernel testing for Float64
jackleland Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- CUDA-ised spectral transform [#602](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/602)
- ConvectiveHeating implemented [#639](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/639)
- Number format flexibility with set! [#634](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/634)
- Forcing/drag for primitive models [#635](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/635)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module SpeedyWeatherCUDAExt

using SpeedyWeather
import CUDA: CUDA, CUDAKernels, CuArray
import CUDA: CUDA, CUDAKernels, CuArray, CUFFT
import AbstractFFTs
using DocStringExtensions

# for RingGrids and LowerTriangularMatrices:
Expand All @@ -11,7 +12,6 @@ LowerTriangularMatrices.nonparametric_type(::Type{<:CuArray}) = CuArray

SpeedyWeather.default_array_type(::Type{GPU}) = CuArray


# DEVICE SETUP FOR CUDA

"""$(TYPEDSIGNATURES)
Expand All @@ -30,4 +30,7 @@ SpeedyWeather.DeviceArray(::GPU, x) = Adapt.adapt(CuArray, x)
Returns a `CuArray` when `device<:GPU` is used. Doesn't uses `adapt`, therefore always returns CuArray."""
SpeedyWeather.DeviceArrayNotAdapt(::GPU, x) = CuArray(x)

include("fourier.jl")
include("legendre.jl")

end # module
202 changes: 202 additions & 0 deletions ext/SpeedyWeatherCUDAExt/fourier.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@

# Override FFT package deciding function
SpeedyTransforms.which_FFT_package(::Type{<:CuArray{<:AbstractFloat}}) = CUFFT

"""$(TYPEDSIGNATURES)
Util function to generate FFT plans based on the array type of the fake Grid
data provided. Uses indexing as we seemingly can't use views for FFT planning
with CUFFT."""
function SpeedyTransforms.plan_FFTs!(
rfft_plans::Vector{AbstractFFTs.Plan},
brfft_plans::Vector{AbstractFFTs.Plan},
rfft_plans_1D::Vector{AbstractFFTs.Plan},
brfft_plans_1D::Vector{AbstractFFTs.Plan},
fake_grid_data::AbstractGridArray{NF, N, <:CuArray{NF}},
scratch_memory_north::CuArray{Complex{NF}},
rings::AbstractArray,
nlons::Vector{<:Int}
) where {NF<:AbstractFloat, N}
# Determine which FFT package to use (currently either FFTW or GenericFFT)
FFT_package = SpeedyTransforms.which_FFT_package(CuArray{NF})

# For each ring generate an FFT plan (for all layers and for a single layer)
for (j, nlon) in enumerate(nlons)
real_matrix_input = fake_grid_data.data[rings[j], :]
complex_matrix_input = scratch_memory_north[1:nlon÷2 + 1, :, j]
real_vector_input = fake_grid_data.data[rings[j], 1]
complex_vector_input = scratch_memory_north[1:nlon÷2 + 1, 1, j]

rfft_plans[j] = FFT_package.plan_rfft(real_matrix_input, 1)
brfft_plans[j] = FFT_package.plan_brfft(complex_matrix_input, nlon, 1)
rfft_plans_1D[j] = FFT_package.plan_rfft(real_vector_input, 1)
brfft_plans_1D[j] = FFT_package.plan_brfft(complex_vector_input, nlon, 1)
end
milankl marked this conversation as resolved.
Show resolved Hide resolved

return rfft_plans, brfft_plans, rfft_plans_1D, brfft_plans_1D
end


function SpeedyTransforms._fourier_batched!( # GRID TO SPECTRAL
f_north::CuArray{<:Complex, 3}, # Fourier-transformed output
f_south::CuArray{<:Complex, 3}, # and for southern latitudes
grids::AbstractGridArray{NF, N, <:CuArray}, # gridded input
S::SpectralTransform, # precomputed transform
milankl marked this conversation as resolved.
Show resolved Hide resolved
) where {NF<:AbstractFloat, N}
(; nlat, nlons, nlat_half) = S # dimensions
(; rfft_plans) = S # pre-planned transforms
nlayers = size(grids, 2)

@boundscheck SpeedyTransforms.ismatching(S, grids) || throw(DimensionMismatch(S, grids))
@boundscheck nlayers == S.nlayers || throw(DimensionMismatch(S, grids))
@boundscheck size(f_north) == size(f_south) == (S.nfreq_max, S.nlayers, nlat_half) || throw(DimensionMismatch(S, grids))

rings = eachring(grids) # precomputed ring indices
@inbounds for j_north in 1:nlat_half # symmetry: loop over northern latitudes only
j = j_north # symmetric index / ring-away from pole index
j_south = nlat - j_north + 1 # corresponding southern latitude index
nlon = nlons[j] # number of longitudes on this ring
nfreq = nlon÷2 + 1 # linear max Fourier frequency wrt to nlon
not_equator = j_north != j_south # is the latitude ring not on equator?

rfft_plan = rfft_plans[j] # FFT planned wrt nlon on ring
ilons = rings[j_north] # in-ring indices northern ring

# FOURIER TRANSFORM in zonal direction, northern latitude
view(f_north, 1:nfreq, 1:nlayers, j) .= rfft_plan * grids.data[ilons, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy is there a way to apply a cuFFT plan such that the result is directly written into some pre-allocated array? Instead of the allocation here on the right side and then writing it into the view of an array on the left?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jackleland because at the moment it sounds we would just need an rfft!/brfft! method that calls this line, either written with LinearAlgebra.mul! for CPUs or with something else for GPUs, so the _fourier_batched! function is actually the same and we wouldn't need any code duplication here, just a new method for rfft!, brfft!


# This is faster but doesn't seem to work for greater than 1D, seems to
# hang.
# view(f_north, 1:nfreq, 1:nlayers, j) .= rfft_plan * view(grids.data, ilons, :)

# and southern latitude if not on Equator
ilons = rings[j_south] # in-ring indices southern ring
if not_equator # skip FFT, redundant because north did that latitude already
view(f_south, 1:nfreq, 1:nlayers, j) .= rfft_plan * grids.data[ilons, :]
else
fill!(f_south[1:nfreq, 1:nlayers, j], 0)
end
end

end

function SpeedyTransforms._fourier_serial!( # GRID TO SPECTRAL
f_north::CuArray{<:Complex, 3}, # Fourier-transformed output
f_south::CuArray{<:Complex, 3}, # and for southern latitudes
grids::AbstractGridArray{NF, N, <:CuArray}, # gridded input
S::SpectralTransform, # precomputed transform
) where {NF<:AbstractFloat, N}
(; nlat, nlons, nlat_half) = S # dimensions
rfft_plans = S.rfft_plans_1D # pre-planned transforms
nlayers = size(grids, 2) # number of vertical layers

@boundscheck SpeedyTransforms.ismatching(S, grids) || throw(DimensionMismatch(S, grids))
@boundscheck nlayers <= S.nlayers || throw(DimensionMismatch(S, grids))
@boundscheck size(f_north) == size(f_south) == (S.nfreq_max, S.nlayers, nlat_half) || throw(DimensionMismatch(S, grids))

rings = eachring(grids) # precomputed ring indices
@inbounds for (k, k_grid) in zip(1:nlayers, eachgrid(grids))
for j_north in 1:nlat_half # symmetry: loop over northern latitudes only
j = j_north # symmetric index / ring-away from pole index
j_south = nlat - j_north + 1 # southern latitude index
nlon = nlons[j] # number of longitudes on this ring (north or south)
nfreq = nlon÷2 + 1 # linear max Fourier frequency wrt to nlon
not_equator = j_north != j_south # is the latitude ring not on equator?

# FOURIER TRANSFORM in zonal direction
rfft_plan = rfft_plans[j] # FFT planned wrt nlon on ring
ilons = rings[j_north] # in-ring indices northern ring
view(f_north, 1:nfreq, k, j) .= rfft_plan * view(grids.data, ilons, k_grid)

# southern latitude, don't call redundant 2nd fft if ring is on equator
ilons = rings[j_south] # in-ring indices southern ring
if not_equator
view(f_south, 1:nfreq, k, j) .= rfft_plan * view(grids.data, ilons, k_grid) # perform FFT
else
fill!(view(f_south, 1:nfreq, k, j), 0)
end
end
end
end


"""$(TYPEDSIGNATURES)
Inverse fast Fourier transform (spectral to grid) of Legendre-transformed inputs `g_north` and `g_south`
to be stored in `grids`. Not to be called directly, use `transform!` instead."""
function SpeedyTransforms._fourier_batched!( # SPECTRAL TO GRID
grids::AbstractGridArray{NF, N, <:CuArray}, # gridded output
g_north::CuArray{<:Complex, 3}, # Legendre-transformed input
g_south::CuArray{<:Complex, 3}, # and for southern latitudes
S::SpectralTransform, # precomputed transform
) where {NF<:AbstractFloat, N}
(; nlat, nlons, nlat_half) = S # dimensions
(; brfft_plans) = S # pre-planned transforms
nlayers = size(grids, 2) # number of vertical layers

@boundscheck SpeedyTransforms.ismatching(S, grids) || throw(DimensionMismatch(S, grids))
@boundscheck nlayers == S.nlayers || throw(DimensionMismatch(S, grids)) # otherwise FFTW complains
@boundscheck size(g_north) == size(g_south) == (S.nfreq_max, S.nlayers, nlat_half) || throw(DimensionMismatch(S, grids))

rings = eachring(grids) # precomputed ring indices
@inbounds for j_north in 1:nlat_half # symmetry: loop over northern latitudes only
j = j_north # symmetric index / ring-away from pole index
j_south = nlat - j_north + 1 # southern latitude index
nlon = nlons[j] # number of longitudes on this ring (north or south)
nfreq = nlon÷2 + 1 # linear max Fourier frequency wrt to nlon
not_equator = j_north != j_south # is the latitude ring not on equator?

brfft_plan = brfft_plans[j] # FFT planned wrt nlon on ring
ilons = rings[j_north] # in-ring indices northern ring

# PERFORM FFT, inverse complex to real, hence brfft
view(grids.data, ilons, :) .= brfft_plan * g_north[1:nfreq, 1:nlayers, j]

# southern latitude, don't call redundant 2nd FFT if ring is on equator
ilons = rings[j_south] # in-ring indices southern ring
if not_equator
view(grids.data, ilons, :) .= brfft_plan * g_south[1:nfreq, 1:nlayers, j]
end
end
end

"""$(TYPEDSIGNATURES)
(Inverse) Fast Fourier transform (spectral to grid) of Legendre-transformed inputs `g_north` and `g_south`
to be stored in `grids`. Serial version that does not require the number of vertical layers to be the same
as precomputed in `S`. Not to be called directly, use `transform!` instead."""
function SpeedyTransforms._fourier_serial!( # SPECTRAL TO GRID
grids::AbstractGridArray{NF, N, <:CuArray}, # gridded output
g_north::CuArray{<:Complex, 3}, # Legendre-transformed input
g_south::CuArray{<:Complex, 3}, # and for southern latitudes
S::SpectralTransform, # precomputed transform
) where {NF<:AbstractFloat, N}
(; nlat, nlons, nlat_half) = S # dimensions
brfft_plans = S.brfft_plans_1D # pre-planned transforms
nlayers = size(grids, 2) # number of vertical layers

@boundscheck SpeedyTransforms.ismatching(S, grids) || throw(DimensionMismatch(S, grids))
@boundscheck nlayers <= S.nlayers || throw(DimensionMismatch(S, grids)) # otherwise FFTW complains
@boundscheck size(g_north) == size(g_south) == (S.nfreq_max, S.nlayers, nlat_half) || throw(DimensionMismatch(S, grids))

rings = eachring(grids) # precomputed ring indices
@inbounds for (k, k_grid) in zip(1:nlayers, eachgrid(grids))
for j_north in 1:nlat_half # symmetry: loop over northern latitudes only
j = j_north # symmetric index / ring-away from pole index
j_south = nlat - j_north + 1 # southern latitude index
nlon = nlons[j] # number of longitudes on this ring (north or south)
nfreq = nlon÷2 + 1 # linear max Fourier frequency wrt to nlon
not_equator = j_north != j_south # is the latitude ring not on equator?

brfft_plan = brfft_plans[j] # FFT planned wrt nlon on ring
ilons = rings[j_north] # in-ring indices northern ring

gn = view(g_north, 1:nfreq, k, j) # data on northern ring, vertical layer k
out = view(grids.data, ilons, k_grid) # view on scratch memory to store transformed data
out .= brfft_plan * gn # perform FFT

# southern latitude, don't call redundant 2nd fft if ring is on equator
gs = view(g_south, 1:nfreq, k, j) # data on southern ring, vertical layer k
ilons = rings[j_south] # in-ring indices southern ring
out = view(grids.data, ilons, k_grid) # data on southern ring, vertical layer k
not_equator && out .= brfft_plan * gs
end
end
end
131 changes: 131 additions & 0 deletions ext/SpeedyWeatherCUDAExt/legendre.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# convert i, j indices of a matrix (here 0-based l,m though...) to a single 1-based running index
import SpeedyWeather.LowerTriangularMatrices: ij2k

# range of the running indices lm in a l-column (degrees of spherical harmonics)
# given the column index m (order of harmonics)
get_lm_range(m, lmax) = ij2k(2*m - 1, m, lmax):ij2k(lmax+m, m, lmax)

# (inverse) legendre transform kernel, called from _legendre!
function inverse_legendre_kernel!(
g_north, # Scratch storage for legendre coefficients
g_south, # before fft
specs_data, # Data passed from spectral grid
legendre_polynomials_data, # Pre-calculated Legendre coefficients
lmax, # Max l-value, from SpectralTransform struct
lon_offsets, # Longitude
kjm_indices
)
tid = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x

if tid <= size(kjm_indices, 1)
# Unpack indices from precomputed kjm_indices using single thread index
k = kjm_indices[tid, 1]
j = kjm_indices[tid, 2]
m = kjm_indices[tid, 3]

# are m, lmax 0-based here or 1-based?
lm_range = get_lm_range(m, lmax) # assumes 1-based
Comment on lines +26 to +27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be adapted (+comment maybe!) for 0-based m, lmax


# view on lower triangular column, but batched in vertical
spec_view = view(specs_data, lm_range, :)
legendre_view = view(legendre_polynomials_data, lm_range, j)

# dot product but split into even and odd harmonics on the fly as this
# is how the previous implementation was enacted
lmax_range = length(lm_range) # number of degrees at order m, lmax-m
isoddlmax = isodd(lmax_range)
lmax_even = lmax_range - isoddlmax # if odd do last odd element after the loop

# Got rid of bounds check, potentially unsafe?
# @boundscheck size(north) == size(south) || throw(DimensionMismatch)
# @boundscheck size(spec_view, 1) == length(legendre_view) || throw(DimensionMismatch)
# @boundscheck size(spec_view, 2) <= length(north) || throw(DimensionMismatch)

# "even" and "odd" coined with 0-based indexing, i.e. the even l=0 mode is 1st element
even_k = zero(eltype(g_south)) # dot product with elements 1, 3, 5, ...
odd_k = zero(eltype(g_north)) # dot prodcut with elements 2, 4, 6, ...

# Switched to while loop as more performant from inside a Kernel
l = 1
while l < lmax_even # dot product in pairs for contiguous memory access
even_k += spec_view[l, k] * legendre_view[l]
odd_k += spec_view[l+1, k] * legendre_view[l+1]
l += 2
end

# now do the last row if lmax is odd
even_k += spec_view[end, k] * (isoddlmax * legendre_view[end])
north = even_k + odd_k
south = even_k - odd_k

# CORRECT FOR LONGITUDE OFFSETTS (if grid points don't start at 0°E)
o = lon_offsets[m, j] # rotation through multiplication with complex unit vector

g_north[m, k, j] += o * north
g_south[m, k, j] += o * south
end
return
end


"""$(TYPEDSIGNATURES)
Inverse Legendre transform, adapted for CUDA and batched across j (lattitude),
k (vertical layers) and m (spherical harmonic order). Not to be used directly,
but called from transform! with CuArrays."""
function SpeedyTransforms._legendre!(
g_north::CuArray{<:Complex, 3}, # Legendre-transformed output, northern latitudes
g_south::CuArray{<:Complex, 3}, # and southern latitudes
specs::LowerTriangularArray, # input: spherical harmonic coefficients
S::SpectralTransform, # precomputed transform
unscale_coslat::Bool = false, # unscale by cosine of latitude on the fly?
)
(; nlat_half) = S # dimensions
(; lmax, mmax ) = S # 0-based max degree l, order m of spherical harmonics
(; legendre_polynomials) = S # precomputed Legendre polynomials
(; jm_index_size, kjm_indices ) = S # kjm loop indices precomputed for threads
(; coslat⁻¹, lon_offsets ) = S
# NOTE: this comes out as a range, not an integer
nlayers = axes(specs, 2) # get number of layers of specs for fewer layers than precomputed in S

@boundscheck SpeedyTransforms.ismatching(S, specs) || throw(DimensionMismatch(S, specs))
@boundscheck size(g_north) == size(g_south) == (S.nfreq_max, S.nlayers, nlat_half) || throw(DimensionMismatch(S, specs))
# reduced_kjm = kjm_indices[1:(nlayers.stop * jm_index_size), :] # get the reduced kjm indices

# @show reduced_kjm
g_north .= 0
g_south .= 0

# INVERSE LEGENDRE TRANSFORM by looping over wavenumbers l, m and layer k
kernel = CUDA.@cuda launch=false inverse_legendre_kernel!(
g_north,
g_south,
specs.data,
legendre_polynomials.data,
lmax,
lon_offsets,
kjm_indices
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(size(kjm_indices, 1), config.threads)
blocks = cld(size(kjm_indices, 1), threads)

# actually launch kernel!
kernel(
g_north,
g_south,
specs.data,
legendre_polynomials.data,
lmax,
lon_offsets,
kjm_indices;
threads,
blocks
)

if unscale_coslat
@inbounds for j in 1:nlat_half # symmetry: loop over northern latitudes only
g_north[:, nlayers, j] .*= coslat⁻¹[j] # scale in place
g_south[:, nlayers, j] .*= coslat⁻¹[j]
end
end
end
1 change: 1 addition & 0 deletions src/SpeedyTransforms/SpeedyTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import FFTW
import GenericFFT
import LinearAlgebra
import Primes
import Adapt: adapt

# SPEEDYWEATHER MODULES
using ..LowerTriangularMatrices
Expand Down
Loading
Loading