From 514c506099a19801c47a652bec0b2295bb5a7de5 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Fri, 14 Apr 2023 13:15:07 +0200
Subject: [PATCH] Generalize to arbitrary args.

---
 src/compiler/compilation.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 test/core/execution.jl      | 17 +++++++++--------
 2 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 138591ac0c..1009ce8aa1 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -538,9 +538,61 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
     return OpaqueClosure{id, typeof(env), sig, rt}(env)
 end
 
+# generated function `ccall`, working around the restriction that ccall type
+# tuples need to be literals. this relies on ccall internals...
+#
+# - `f`: pointer to the function to call
+# - `_rettyp`: return type of the call, passed as a type (`Type{R}`)
+# - `_types`: argument types of the call, passed as a tuple type (`Type{Tuple{...}}`)
+# - `vals...`: argument values, one per entry of `_types`
+#
+# passing a number of `vals` that differs from `_types` throws an `ArgumentError`.
+@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
+    # inside the generator, arguments are bound to their types: unwrap the `Type{...}`
+    rettyp = _rettyp.parameters[1]
+    types = _types.parameters[1].parameters
+
+    # `zip` stops at the shortest input, so without this check surplus values would be
+    # dropped silently and missing ones would yield a malformed call
+    nargs = length(vals)
+    if nargs != length(types)
+        msg = "generated_ccall: expected $(length(types)) argument(s), got $nargs"
+        return :(throw(ArgumentError($msg)))
+    end
+
+    ex = quote end
+    args = [:(vals[$i]) for i in 1:nargs]
+
+    # cconvert
+    cconverted = [Symbol("cconverted_$i") for i in 1:nargs]
+    for (dst, typ, src) in zip(cconverted, types, args)
+        append!(ex.args, (quote
+            $dst = Base.cconvert($typ, $src)
+        end).args)
+    end
+
+    # unsafe_convert
+    unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:nargs]
+    for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
+        append!(ex.args, (quote
+            $dst = Base.unsafe_convert($typ, $src)
+        end).args)
+    end
+
+    # the `unsafe_convert`ed values are the actual arguments; the `cconvert` results
+    # trail them so that they stay rooted for the duration of the call
+    call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
+                QuoteNode(:ccall), unsafe_converted..., cconverted...)
+    push!(ex.args, call)
+    return ex
+end
+
 # device-side call to an opaque closure
-function (oc::OpaqueClosure{F})(a, b) where F
+# (type parameters as set by `generate_opaque_closure`: `{id, typeof(env), sig, rt}`,
+# so `A` is the argument signature and `R` the return type)
+function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
     ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
     assume(ptr != C_NULL)
-    return ccall(ptr, Int, (Int, Int), a, b)
+    # i.e. `ccall(ptr, R, (A...), args...)`, but with non-literal types
+    return generated_ccall(ptr, R, A, args...)
 end
diff --git a/test/core/execution.jl b/test/core/execution.jl
index 716d094a34..6326f3ff1c 100644
--- a/test/core/execution.jl
+++ b/test/core/execution.jl
@@ -1118,21 +1118,22 @@ end
 
 # basic closure, constructed from CodeInfo
 let
-    ir, rettyp = only(Base.code_typed(+, (Int, Int)))
+    ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
     oc = CUDA.OpaqueClosure(ir)
 
-    c = CuArray([0])
-    a = CuArray([1])
-    b = CuArray([2])
+    d = CuArray([1])
+    a = CuArray([2])
+    b = CuArray([3])
+    c = CuArray([4])
 
-    function kernel(oc, c, a, b)
+    function kernel(oc, d, a, b, c)
         i = threadIdx().x
-        @inbounds c[i] = oc(a[i], b[i])
+        @inbounds d[i] = oc(a[i], b[i], c[i])
         return
     end
 
-    @cuda threads=1 kernel(oc, c, a, b)
+    @cuda threads=1 kernel(oc, d, a, b, c)
 
-    @test Array(c)[] == 3
+    @test Array(d)[] == 24
 end
 
 end