Support for CUFFT callbacks #75

Open
vchuravy opened this issue Nov 10, 2016 · 28 comments
Labels: cuda kernels, enhancement

Comments

@vchuravy (Member)

One interesting use case is to pass Julia functions as callbacks to cuFFT:
https://devblogs.nvidia.com/parallelforall/cuda-pro-tip-use-cufft-callbacks-custom-data-processing/

> To register the callbacks with the cuFFT plan, the first step is to get the device function pointers from the device onto the host.
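
For reference, the flow described in that post looks roughly like this in CUDA C (a minimal sketch; the callback body and all names here are illustrative):

#include <cufft.h>
#include <cufftXt.h>

// A trivial load callback: cuFFT calls this for every input element it reads.
__device__ cufftComplex myLoadCallback(void *dataIn, size_t offset,
                                       void *callerInfo, void *sharedPtr)
{
    return ((cufftComplex *)dataIn)[offset];
}

// Device-side global holding the callback's address; the pointer only
// exists on the device, so it has to be copied to the host before use.
__device__ cufftCallbackLoadC d_loadCallbackPtr = myLoadCallback;

void registerCallback(cufftHandle plan)
{
    cufftCallbackLoadC h_loadCallbackPtr;
    cudaMemcpyFromSymbol(&h_loadCallbackPtr, d_loadCallbackPtr,
                         sizeof(h_loadCallbackPtr));
    cufftXtSetCallback(plan, (void **)&h_loadCallbackPtr,
                       CUFFT_CB_LD_COMPLEX, NULL);
}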

@maleadt (Member) commented Nov 10, 2016

This will require quite some work, because we don't have symbol resolution other than calling functions from emit_invoke (i.e., no function pointers). We'd need proper JIT symbol lookup, like fptr/jl_generate_fptr, and along the way we'd have to deal with multiple contexts, the fact that API calls only work at runtime vs. compile time, etc.

It's also not clear what exactly the pointer should be (a CUfunction? or, more likely, something obtained via cuModuleGetGlobal).

@maleadt transferred this issue from JuliaGPU/CUDAnative.jl May 27, 2020
@maleadt added the cuda kernels and enhancement labels May 27, 2020
@david-macmahon (Contributor)

In reply to your question in #614 about how important this feature is, I think the answer is that it is very important in certain circumstances. For example, we currently use CuFFT callbacks in a CUDA C program that performs long FFTs of 8-bit signed integer data (equivalent to Complex{Int8}) and then produces integrated power spectra. Callbacks are used on both inputs and outputs.

Input callbacks

CuFFT does not handle 8-bit data directly, so without callbacks one has to pre-convert the input array to 32-bit floats. This increases the size of the CuFFT input buffer by a factor of at least 4 (assuming the CuFFT output buffer can be used to stage the 8-bit input data) compared to using callbacks, where the input buffer holds the 8-bit samples and conversion to 32-bit floats happens "on the fly", so the 32-bit floats never occupy global memory. This reduced memory requirement allows longer (or just more) FFTs to be performed.
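
For illustration, the conversion amounts to a load callback of roughly this shape (a sketch, not our production code; char2 matches the interleaved Complex{Int8} layout, and all names are made up):

#include <cufftXt.h>

// Load callback: convert one 8-bit complex sample to 32-bit floats as
// cuFFT fetches it, so the floats never touch global memory.
__device__ cufftComplex load8bit(void *dataIn, size_t offset,
                                 void *callerInfo, void *sharedPtr)
{
    char2 s = ((char2 *)dataIn)[offset];  // one Complex{Int8} sample
    cufftComplex r;
    r.x = (float)s.x;
    r.y = (float)s.y;
    return r;
}

__device__ cufftCallbackLoadC d_load8bitPtr = load8bit;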

Output callbacks

The integrated power spectra are produced by integrating (i.e. summing) the magnitude squared (aka abs2()) of the complex outputs of multiple FFTs. Without callbacks, the FFT must write the output samples to global memory, and then another kernel must read those just-written samples back from global memory before computing the magnitude squared and adding it into the integrated power spectra. With callbacks, the magnitude squared calculation and integration can happen "on the fly" as the output samples are produced, eliminating the need to read the just-written output samples from global memory. This reduces the load on GPU global memory, which can be good for overall throughput.
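
The store side is roughly (again a sketch; FFT_LENGTH and the callerInfo handling are illustrative, and atomicAdd lets multiple batch members accumulate into the same bin):

#include <cufftXt.h>

#define FFT_LENGTH 4096  // illustrative

// Store callback: instead of writing the FFT output, accumulate its
// magnitude squared into an integration buffer passed via callerInfo.
__device__ void storeAbs2(void *dataOut, size_t offset, cufftComplex element,
                          void *callerInfo, void *sharedPtr)
{
    float *spectrum = (float *)callerInfo;
    size_t bin = offset % FFT_LENGTH;  // fold batch members onto one spectrum
    atomicAdd(&spectrum[bin], element.x * element.x + element.y * element.y);
}

__device__ cufftCallbackStoreC d_storeAbs2Ptr = storeAbs2;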

CuFFT callbacks allow us to do longer FFTs with (potentially) higher throughput than would be possible without callbacks. This makes CUDA.jl support of this feature important for us so we can fully replace the C program with a Julia implementation. Without this feature, we will have to keep the C program around for when the desired FFT lengths with 32-bit input buffers exceed the memory available on the GPU.

@maleadt (Member) commented Jan 8, 2021

I was having another look at the documentation, and:

> NOTE: The callback API is available in the statically linked cuFFT library only, and only on 64 bit LINUX operating systems.

so that's a problem in itself. Or maybe that only applies to the helper routines, as cufftXtSetCallback is available in libcufft.so (we just haven't wrapped cufftXt.h -- not sure when and how the Xt functionality is available exactly).

@maleadt (Member) commented Jan 8, 2021

From https://webcache.googleusercontent.com/search?q=cache:RuVLoPCjAx4J:https://quabr.com/40565350/can-i-use-nvrtc-to-compile-cufft-callback-functions-that-use-compiler-generated+&cd=1&hl=en&ct=clnk&gl=be&client=firefox-b-d:



1. Generate the source code for the callback function based on the parameters that I'm provided at runtime.
2. Use the NVRTC library to compile the source code to PTX.
3. Use the CUDA Driver API's linker function to compile the PTX to a cubin.
4. Load the resulting cubin using the CUDA Driver API's module management functionality.
5. Use the module API to obtain a pointer to the location of the callback function on the device.
6. Pass the resulting callback function pointer to the cuFFT library via cufftXtSetCallback().

We have all of that, so it should be possible.
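
Roughly, steps 3-6 chain together like this with the driver API (a sketch assuming ptx holds the NVRTC output and defines a callback_alias global containing the callback's device address; error checking omitted):

#include <cuda.h>
#include <cufftXt.h>

void setCallbackFromPtx(cufftHandle plan, const char *ptx, size_t ptxSize)
{
    // Link the PTX into a cubin.
    CUlinkState link;
    void *cubin;
    size_t cubinSize;
    cuLinkCreate(0, NULL, NULL, &link);
    cuLinkAddData(link, CU_JIT_INPUT_PTX, (void *)ptx, ptxSize,
                  "callback.ptx", 0, NULL, NULL);
    cuLinkComplete(link, &cubin, &cubinSize);

    // Load the cubin as a module.
    CUmodule mod;
    cuModuleLoadData(&mod, cubin);

    // Read the device function pointer out of the module's global.
    CUdeviceptr alias;
    size_t bytes;
    cuModuleGetGlobal(&alias, &bytes, mod, "callback_alias");
    void *h_callback;
    cuMemcpyDtoH(&h_callback, alias, sizeof(h_callback));

    // Hand the pointer to cuFFT.
    cufftXtSetCallback(plan, &h_callback, CUFFT_CB_LD_COMPLEX, NULL);
    cuLinkDestroy(link);
}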

@david-macmahon (Contributor)

> I was having another look at the documentation, and:
>
> > NOTE: The callback API is available in the statically linked cuFFT library only, and only on 64 bit LINUX operating systems.
>
> so that's a problem in itself. Or maybe that only applies to the helper routines, as cufftXtSetCallback is available in libcufft.so (we just haven't wrapped cufftXt.h -- not sure when and how the Xt functionality is available exactly).

Wow, yeah, that could be a show stopper. I checked my C program that uses CuFFT callbacks and it looks like it does statically link the CuFFT library. I'll try building it with dynamic linking to the CuFFT library to see if that works.

@maleadt (Member) commented Jan 8, 2021

I'm also not sure how to look up the function pointer of a kernel; cuModuleGetGlobal doesn't seem to work:

julia> @device_code_ptx kernel = @cuda launch=false identity(nothing)
// PTX CompilerJob of kernel identity(Nothing) for sm_75

//
// Generated by LLVM NVPTX Back-End
//

.version 6.3
.target sm_75
.address_size 64

        // .globl       _Z19julia_identity_2585v // -- Begin function _Z19julia_identity_2585v
.weak .global .align 8 .u64 exception_flag;
                                        // @_Z19julia_identity_2585v
.visible .entry _Z19julia_identity_2585v()
{


// %bb.0:                               // %top
        ret;
                                        // -- End function
}

julia> kernel_global = CuGlobal{Ptr{Cvoid}}(kernel.mod, "_Z19julia_identity_2585v")
ERROR: CUDA error: named symbol not found (code 500, ERROR_NOT_FOUND)

We could always emit a global datastructure that points to the functions in the module, a la https://forums.developer.nvidia.com/t/how-can-i-use-device-function-pointer-in-cuda/14405/18, but that seems cumbersome.

EDIT: looks like in C too you need to assign the device function to a global variable, https://github.com/zchee/cuda-sample/blob/05555eef0d49ebdde999f5430f185a225ef00dcd/7_CUDALibraries/simpleCUFFT_callback/simpleCUFFT_callback.cu#L48-L57.

@david-macmahon (Contributor)

I don't know that much about the Driver API's module handling, but does CuGlobal() end up calling cuModuleGetGlobal()? Not sure how that differs from cuModuleGetFunction().

FWIW, the C program I use is available here. It uses the runtime API and the callback functions are defined as __device__ functions and then the device function pointers come from cudaMemcpyFromSymbol() (probably just following examples in the CuFFT docs).

@maleadt (Member) commented Jan 8, 2021

cuModuleGetFunction only works for kernels, not for device functions.

Looks like we'll need to emit something like:

static __device__ void callback_f()
{
    return;
}

typedef void (*callback_t)();

 __device__ callback_t callback_alias = callback_f;

which compiles to PTX along these lines:

.func _ZN61_INTERNAL_39_tmpxft_0000b070_00000000_7_test_cpp1_ii_16bd2dde10callback_fEv
()
;
.global .align 8 .u64 callback_alias = _ZN61_INTERNAL_39_tmpxft_0000b070_00000000_7_test_cpp1_ii_16bd2dde10callback_fEv;

.func _ZN61_INTERNAL_39_tmpxft_0000b070_00000000_7_test_cpp1_ii_16bd2dde10callback_fEv()
{
        ret;
}

@david-macmahon (Contributor)

I noticed the sample you linked to also uses cudaMemcpyFromSymbol(). I also noticed its Makefile has:

LIBRARIES += -lcufft_static -lculibos

so it too is using the statically linked CuFFT library.

@maleadt (Member) commented Jan 8, 2021

julia> kernel = nothing

julia> @device_code_ptx kernel = cufunction(identity, Tuple{Nothing}; kernel=false)
// PTX CompilerJob of function identity(Nothing) for sm_75

//
// Generated by LLVM NVPTX Back-End
//

.version 6.3
.target sm_75
.address_size 64

        // .globl       julia_identity_4839     // -- Begin function julia_identity_4839
.visible .func julia_identity_4839
()
;
.visible .global .align 8 .u64 entrypoint = julia_identity_4839;
.weak .global .align 8 .u64 exception_flag;
                                        // @julia_identity_4839
.visible .func julia_identity_4839()
{


// %bb.0:                               // %top
        ret;
                                        // -- End function
}

julia> kernel
CUDA.DeviceFunction{identity, Tuple{Nothing}}(CuContext(0x00000000050fb8a0, instance 12bb30420e1458f8), CuModule(Ptr{Nothing} @0x000000000952d580, CuContext(0x00000000050fb8a0, instance 12bb30420e1458f8)), CuPtr{Nothing}(0x0000000000000001))

julia> CuGlobal{Ptr{Cvoid}}(kernel.mod, "entrypoint")[]
Ptr{Nothing} @0x0000000000000001

That 0x1 doesn't look right...

@maleadt (Member) commented Jan 8, 2021

Actually, it looks like cudaMemcpyFromSymbol yields the same value:

#include <stdio.h>

static __device__ void callback_f()
{
    return;
}

typedef void (*callback_t)();

 __device__ callback_t callback_alias = callback_f;

int main() {

    callback_t host_callback = NULL;

    cudaMemcpyFromSymbol(&host_callback,
                         callback_alias,
                         sizeof(callback_t));

    printf("callback: %p\n", host_callback);
    return 0;
}

which prints:

callback: 0x1

@leofang commented Jan 10, 2021

@maleadt I don't know any Julia folks here or how Julia works at all (I would love to!), but I found this issue via Google and I'd really like to save you the trouble, since this feature took me over a year (even with a couple of NVIDIA power guys' help!) to understand and implement. I finally made it work in CuPy, but it was awfully ugly. In short:

  • You have to link to libcufft_static.a, not libcufft.so.
  • The callback code needs to be known to the cuFFT library at compile/link time. Any attempt to compile it at runtime, retrieve its device function pointer, and pass it to the cuFFT library would just not work (this includes cuModuleGetFunction, cudaMemcpyToSymbol, etc).
  • A cuFFT plan handle generated from calls to libcufft.so cannot be passed to functions in libcufft_static.a; consider them separate libraries with no mix-and-match.

CuPy's solution (updated):

  • Normally we have a cuFFT module that is linked to libcufft.so to provide Python wrappers for the cuFFT C APIs.
  • If an attempt to use cuFFT callback is detected at runtime, we compile an entirely new cuFFT module (linked to libcufft_static.a) just for that set of callbacks (load & save), so for each unique set there's a distinct Python module (shared library).

See, for example, cupy/cupy#4105, cupy/cupy#4141, or just search "cufft callback" in our repo https://github.com/cupy/cupy/pulls?q=is%3Apr+cufft+callback for how we struggled to get it right. It's highly nontrivial: we know exactly how to make it work in C/C++, but when coupling it to a dynamic environment like Python all kinds of issues appear.

But again, I know nothing about Julia (yet!), so I hope things could be much simpler for you 🤞 I'm happy to help (and learn from you). Let me know if you find anything unclear.

@maleadt (Member) commented Jan 11, 2021

@leofang Thanks for chiming in! It's a real bummer though. In Julia/CUDA.jl, we actually ship CUDA libraries and don't assume nvcc or even a host toolchain is available, so dynamically compiling a library that contains both the callbacks and a statically-linked cuFFT isn't a viable option (without a lot of work). 😞

@leofang commented Jan 12, 2021

> In Julia/CUDA.jl, we actually ship CUDA libraries and don't assume nvcc or even a host toolchain is available, so dynamically compiling a library that contains both the callbacks and a statically-linked cuFFT isn't a viable option (without a lot of work). 😞

Hi @maleadt, yes we usually don't assume that either, but for cuFFT callbacks a local CUDA Toolkit has to be there unfortunately...

Let me take this opportunity and get to know CUDA.jl a bit more: If a user has a CUDA Toolkit installation locally, can he/she build and use CUDA.jl against that? This was our assumption (we have HPC people asking for this feature, and usually this condition is satisfied in an HPC cluster).

The callbacks have to be very simple kernels in order to see the benefit, though. The gain from saving extra global-memory I/O can be eroded by many factors. Our experience shows that callbacks are best kept as simple as, say, windowing or arithmetic operations over the input data. Any slightly more complicated operation (say, computing sin(x) for the input x) can actually make performance worse than not using callbacks at all, at least on a few high-end GPUs. I personally don't use this; my work on the callbacks was just to support my colleagues and also for exploratory purposes...
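
For a concrete idea of "simple", think of a windowing load callback of roughly this shape (an illustrative sketch; FFT_LENGTH is made up and the window is assumed to be precomputed and passed via callerInfo):

#include <cufftXt.h>

#define FFT_LENGTH 4096  // illustrative

// Apply a precomputed window to each input sample as cuFFT fetches it.
__device__ cufftComplex loadWindowed(void *dataIn, size_t offset,
                                     void *callerInfo, void *sharedPtr)
{
    float w = ((const float *)callerInfo)[offset % FFT_LENGTH];
    cufftComplex x = ((cufftComplex *)dataIn)[offset];
    x.x *= w;
    x.y *= w;
    return x;
}

__device__ cufftCallbackLoadC d_loadWindowedPtr = loadWindowed;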

@maleadt (Member) commented Jan 12, 2021

> Let me take this opportunity and get to know CUDA.jl a bit more: If a user has a CUDA Toolkit installation locally, can he/she build and use CUDA.jl against that? This was our assumption (we have HPC people asking for this feature, and usually this condition is satisfied in an HPC cluster).

Yes, by setting the JULIA_CUDA_USE_BINARYBUILDER environment variable to false. By default, we download CUDA ourselves, because users are pretty bad at making sure their toolkit matches the CUDA support level of their driver. In the future, we'll probably use a more integrated feature (like Overrides), but this'll do for now.

> we usually don't assume that either, but for cuFFT callbacks a local CUDA Toolkit has to be there unfortunately...

So if I understand correctly (I haven't read through the entire PR yet), you dynamically create a shared library from cupy_cufftXt.cu that contains the cuFFT static library, the callbacks, and code to look up the callback and configure the plan. Do you know if this is also possible with PTX-based callbacks? We never emit CUDA C code, but compile Julia directly to PTX.

@david-macmahon (Contributor)

> I checked my C program that uses CuFFT callbacks and it looks like it does statically link the CuFFT library. I'll try building it with dynamic linking to the CuFFT library to see if that works.

In the interest of completeness I linked my program that uses CuFFT callbacks with the dynamic CuFFT library. Not surprisingly, it linked OK but at runtime the cufftXtSetCallback() function returned a CUFFT_NOT_IMPLEMENTED error.

@leofang commented Jan 17, 2021

Hi @maleadt @david-macmahon Sorry I dropped the ball. Had a rough week 😞

> By default, we download CUDA ourselves, because users are pretty bad at making sure their toolkit matches the CUDA support level of their driver.

Oh wow. Yes, I agree this is always a big common pitfall.

I wonder how JuliaGPU resolved the license issue, though. Do you get some kind of special permission from NVIDIA? We wanted to do similar things too, mainly because there are some device function headers (such as those for cuRAND device routines) that come with the CUDA Toolkit and that we want to distribute, but by default most of the headers in the CUDA Toolkit are not redistributable, so we had to change the code design just to avoid the distribution problem, which is silly. For runtimes that can JIT, not being allowed to redistribute headers couldn't be more annoying. 😞

> So if I understand correctly (I haven't read through the entire PR yet), you dynamically create a shared library from cupy_cufftXt.cu that contains the cuFFT static library, the callbacks, and code to look up the callback and configure the plan. Do you know if this is also possible with PTX-based callbacks? We never emit CUDA C code, but compile Julia directly to PTX.

I actually took a quick look at GPUCompiler.jl. That's an amazing project! Did you guys delegate to LLVM/NVVM for the last mile of compilation? If so, it sounds similar to Numba's approach. Just for my own curiosity 🙂

Regardless of the compiler internals, I doubt it'd work. That's because, in the end, your compiler (and so does NVRTC, which CuPy uses) returns PTX/cubin in which the device functions are looked up via cuModuleGetFunction to get their pointers. But IIUC this circles back to No. 2 of my earlier comment. At compile time, the callback device function has to be passed to cufftXtSetCallback so that the unused symbols can be identified and pruned by the linker (otherwise at Python import time we get tons of undefined symbols leaked from libcufft_static.a; not sure what would happen with Julia, though), but that's a host function, so it's just incompatible with the whole compilation workflow that CUDA.jl or CuPy adopts.

> In the interest of completeness I linked my program that uses CuFFT callbacks with the dynamic CuFFT library. Not surprisingly, it linked OK but at runtime the cufftXtSetCallback() function returned a CUFFT_NOT_IMPLEMENTED error.

Yes, this is what No. 3 of my earlier comment was about. I also noticed that, so one thing I tried was to generate a cuFFT plan from the plan-generation wrapper linked to the shared library, and then pass that plan, with callbacks etc., to the wrapper of cufftXtSetCallback() linked to the static library -- the approach I called "mix-and-match" earlier. If it worked, it would also have been possible to circumvent the incompatible-workflow issue, but unfortunately it just does not work. So the lesson is that once any callbacks are involved, all cuFFT API calls must come from the same library, namely the static one, which leads to CuPy's per-callback module solution...

@maleadt (Member) commented Jan 18, 2021

> I wonder how JuliaGPU resolved the license issue, though. Do you get some kind of special permission from NVIDIA?

Most of the stuff we need falls under the redistribution license, but yeah for the rest we got in touch with NVIDIA.

> I actually took a quick look at GPUCompiler.jl. That's an amazing project! Did you guys delegate to LLVM/NVVM for the last mile of compilation?

Thanks! Yes, we use LLVM for the compilation to PTX. NVVM isn't really usable; even on CUDA 11.2 it only supports LLVM 7.0's IR, whereas the upcoming Julia 1.6 is using LLVM 11. But worse, we want to support a variety of CUDA drivers and consequently CUDA toolkits, which means we'd have to support multiple versions of NVVM IR (aka. LLVM IR). And since there isn't really a way to downgrade LLVM IR, that just isn't a viable option. We've mentioned this to NVIDIA on countless occasions (please approach NVVM differently, or contribute to LLVM's NVPTX), but they don't budge 😞

@david-macmahon (Contributor)

Sorry-not-sorry for reviving this 18+ month old issue, but I recently encountered another use case where cuFFT callbacks could (I think) really boost performance. My application needs to multiply each element of an input matrix by a function of the element's indices prior to computing the FFT of the resulting element-wise product matrix. I currently do this with a broadcast operation that reads the (large) input matrix from main GPU memory, performs the complex multiply, and writes the results back to main GPU memory; the FFT then reads the matrix from main GPU memory again. With cuFFT callbacks, this element-wise pre-multiplication could be performed as cuFFT fetches the data from main GPU memory, thereby saving a complete read and write of the (large) matrix data from/to main GPU memory.
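
In callback terms, this would be a load callback of roughly this shape (a sketch; NCOLS and factor() are hypothetical stand-ins for my actual matrix width and index-dependent function):

#include <cufftXt.h>

#define NCOLS 8192  // hypothetical matrix width

// Hypothetical index-dependent factor; stands in for the real function.
__device__ cufftComplex factor(size_t row, size_t col);

// Multiply each element by factor(row, col) as cuFFT fetches it, instead
// of materializing the product matrix in global memory first.
__device__ cufftComplex loadPremultiplied(void *dataIn, size_t offset,
                                          void *callerInfo, void *sharedPtr)
{
    size_t row = offset / NCOLS, col = offset % NCOLS;
    cufftComplex x = ((cufftComplex *)dataIn)[offset];
    cufftComplex m = factor(row, col);
    cufftComplex y;
    y.x = x.x * m.x - x.y * m.y;  // complex multiply x * m
    y.y = x.x * m.y + x.y * m.x;
    return y;
}

__device__ cufftCallbackLoadC d_loadPremultipliedPtr = loadPremultiplied;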

But I understand that cuFFT callbacks currently require static linking, so I guess this is really more of a cuFFT issue for NVIDIA (to support cuFFT callbacks with dynamically linked libcufft) rather than a CUDA.jl issue. 😞

@mrsamuel76

I am trying to implement a power spectral density calculation in the same way as mentioned above; see david-macmahon's comment from Jan 8, 2021 (#75 (comment)). I wonder what the most efficient way to implement the store (output) callback is. I am using atomicAdd() to correctly sum up the magnitude squared from all batch members. Has anyone come up with a better implementation? :-)

@david-macmahon (Contributor)

I think the use of cuFFT callbacks from Julia is off the table unless NVIDIA comes up with a more flexible approach to callbacks that doesn't require statically linking with libcufft_static.a. You're probably best off using Base.mapreducedim!(abs2, +, psdarray, fftout) to compute the PSD (assuming you can spare the memory), where psdarray has the same dimensions as fftout except for the batch dimension (or maybe a singleton dimension in its place?).

@leofang commented Mar 15, 2024

> I think the use of cuFFT callbacks from Julia is off the table unless NVIDIA comes up with a more flexible approach to callbacks that doesn't require statically linking with libcufft_static.a.

Yes, we (NVIDIA) hear you and it's finally possible now 🙂 See cupy/cupy#8242 for example.
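
The key difference is that the callback is now handed to cuFFT as LTO IR (or a fatbin containing it) at plan time, rather than as a device function pointer. If I recall our docs correctly, usage is roughly like this (a sketch; check the cuFFT documentation for the exact signature):

#include <cufftXt.h>

// lto_blob: the callback compiled to LTO IR (e.g. by NVRTC with -dlto);
// "my_load_callback" names the callback symbol inside that blob.
// Signature per the cuFFT LTO docs, from memory -- verify before use.
cufftResult setLtoCallback(cufftHandle plan,
                           const void *lto_blob, size_t lto_blob_size)
{
    return cufftXtSetJITCallback(plan, "my_load_callback",
                                 lto_blob, lto_blob_size,
                                 CUFFT_CB_LD_COMPLEX, NULL);
}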

@maleadt (Member) commented Mar 15, 2024

That's great! How do we generate LTO IR? Our current toolchain targets PTX through LLVM's NVPTX; compiling Julia to C++ for use with NVRTC is not possible, and targeting NVVM would also require lots of work.

@leofang commented Mar 15, 2024

To my knowledge currently only NVVM and NVRTC can emit LTO IR. I doubt NVPTX understands this format.

@leofang commented Mar 15, 2024

> and targeting NVVM would also require lots of work

Thinking about it more, @maleadt would you be able to elaborate on the challenges on Julia's side so that I can understand better?

IIRC Numba manages to downgrade and then translate LLVM IR to NVVM IR (ex: here), so that it can use NVVM to emit PTX (it's arguably a hack for sure). Now it is also adding support for LTO IR (here), with the help of pynvjitlink (Python wrapper for nvJitLink), and our internal tests indicate that it works fine.

@maleadt (Member) commented Mar 16, 2024

> Thinking about it more, @maleadt would you be able to elaborate on the challenges on Julia's side so that I can understand better?

Switching to a different back-end is a lot of work (the intrinsics are not compatible, we would need to make sure the performance of the generated code is similar, etc.), and I'm hesitant to migrate from an open-source back-end we can debug/improve to a closed-source one. But regardless of that, there are a couple of practical issues too:

  • we'd need to maintain a proper IR downgrader; string-replacing textual IR isn't an acceptable solution (e.g., what will happen once opaque pointers land). We're working on one now for Metal, but that's targeting LLVM 5.
  • we support several versions of CUDA, currently 11.4-12.4, which would imply supporting several versions of NVVM IR. That is a lot of work (lots of conditional code, multiple IR downgraders, etc)
  • alternatively, we could just always redistribute and use the latest libNVVM and use that to target all CUDA versions. that has issues too:
    • the generated code may be incompatible with the actual driver, by using a too recent PTX ISA (see bug 4159797). this could be remedied by making the target ISA configurable
    • I presume there may also be incompatibilities with the rest of the toolkit; e.g., would the LTO IR generated by NVVM from CUDA 13.x be compatible with cuFFT from CUDA 12.5? Would we just need to ensure we're using the latest libnvjitlink too?

Despite these issues, I am planning to experiment with NVVM as soon as we have the necessary ISA configurability (and I have created https://github.com/JuliaGPU/NVVM.jl as a starting point), but I don't expect to switch to it any time soon unless there's a very compelling reason to do so. cuFFT callbacks, albeit very useful, don't seem important enough to justify the switch. If e.g. performance were a lot better, that would be a different story.

@llukas commented Nov 18, 2024

@maleadt were you able to experiment with cuFFT LTO callbacks? It is now a production feature: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html

@maleadt (Member) commented Nov 18, 2024

No, we currently can't use NVVM due to missing features (see JuliaGPU/NVVM.jl#1 (comment)), and NVVM seems to be a requirement for using cuFFT LTO callbacks.
