Consider supporting more non-isbitstype arguments #639

Open
maleadt opened this issue Oct 8, 2024 · 0 comments
maleadt commented Oct 8, 2024

Our current validator rejects non-isbitstype arguments, with the exception of arguments whose type passes the Core.Compiler.isconstType test. This makes it possible to, e.g., broadcast types, as these arguments are only used to specialize the kernel and are not actually used by the generated code (even though they are passed, as opposed to ghost/singleton values).
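
For illustration, here is a minimal host-side sketch of how the two predicates behave on a few types. It needs no GPU. Bar is just an example type, and Core.Compiler.isconstType is a compiler internal, so its location may differ between Julia versions.

```julia
# Example type, only for illustration.
struct Bar{T}
    a::T
end

# Plain data: isbits, so the validator accepts it.
@show isbitstype(Bar{Float32})

# A type passed as a value is not isbits...
@show isbitstype(Type{Bar{Float32}})

# ...but it is a "const type": the argument type alone pins down the value,
# which is why the validator lets it through.
@show Core.Compiler.isconstType(Type{Bar{Float32}})

# A heap-allocated host object fails both tests and is rejected.
@show isbitstype(Vector{Float32})
@show Core.Compiler.isconstType(Vector{Float32})
```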


In JuliaGPU/CUDA.jl#2514, it was noted that some code (notably closure-heavy code generated by Zygote) still refuses to compile, even though the generated code doesn't actually use the non-isbits value. For example:

struct Bar{T}
    a::T
end

function main()
    a = cu(zeros(5))

    capture = Bar
    function closure(arg)
        capture(arg)
    end

    function kernel(f, x)
        f(x[])
        return
    end
    @cuda kernel(closure, a)
end

The problem here is that the closure captures the type, which makes the closure itself non-isbits. Because the closure's type is not a const type either, compilation fails, even though the generated code is perfectly fine:

define ptx_kernel void @_Z6kernel7closureI4TypeI3BazI1TEEE13CuDeviceArrayI7Float32Li1ELi1EE({ i64, i32 } %state, [1 x {}*] %0, { i8 addrspace(1)*, i64, [1 x i64], i64 } %1) local_unnamed_addr {
conversion:
  ret void
}
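
The type-level side of this can be reproduced on the host without CUDA. The following is a small sketch under that assumption; make_closure is a hypothetical helper that mirrors the capture in the example, using a concrete Bar{Float32} for simplicity.

```julia
struct Bar{T}
    a::T
end

# Hypothetical helper: build a closure that captures a type, as in the example.
function make_closure()
    capture = Bar{Float32}
    closure(arg) = capture(arg)
    return closure
end

f = make_closure()

# The captured type is stored as a field of the closure object.
@show fieldtypes(typeof(f))

# The closure is therefore not isbits, and since its type is not a Type{...},
# it does not pass the isconstType exception either.
@show isbitstype(typeof(f))
@show Core.Compiler.isconstType(typeof(f))
```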

Note that the closure argument ([1 x {}*] %0 in the IR above) really does contain a managed pointer; the kernel body simply never touches it. In this case, we can work around the issue by reviving the more lenient validation removed in #24, where we not only checked for Core.Compiler.isconstType, but also whether the value is actually used:

diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb6..a4cff8f 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -88,8 +88,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     end

     @timeit_debug to "Validation" begin
-        check_method(job)   # not optional
-        validate && check_invocation(job)
+        check_method(job)
     end

     prepare_job!(job)
@@ -99,6 +98,10 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool

     ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)

+    validate && @timeit_debug to "Validation" begin
+        check_invocation(job, ir_meta.entry)
+    end
+
     if output == :llvm
         if strip
             @timeit_debug to "strip debug info" strip_debuginfo!(ir)
diff --git a/src/validation.jl b/src/validation.jl
index e1a355b..9f1f869 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -66,7 +66,7 @@ function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10)
     return msg
 end

-function check_invocation(@nospecialize(job::CompilerJob))
+function check_invocation(@nospecialize(job::CompilerJob), entry::LLVM.Function)
     sig = job.source.specTypes
     ft = sig.parameters[1]
     tt = Tuple{sig.parameters[2:end]...}
@@ -77,6 +77,9 @@ function check_invocation(@nospecialize(job::CompilerJob))
     real_arg_i = 0

     for (arg_i,dt) in enumerate(sig.parameters)
+        println(Core.stdout, arg_i)
+        println(Core.stdout, dt)
+
         isghosttype(dt) && continue
         Core.Compiler.isconstType(dt) && continue
         real_arg_i += 1
@@ -89,9 +92,13 @@ function check_invocation(@nospecialize(job::CompilerJob))
         end

         if !isbitstype(dt)
-            throw(KernelError(job, "passing and using non-bitstype argument",
-                """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
-                    $(explain_nonisbits(dt))"""))
+            param = parameters(entry)[real_arg_i]
+            if !isempty(uses(param))
+                println(Core.stdout, string(entry))
+                throw(KernelError(job, "passing and using non-bitstype argument",
+                      """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
+                         $(explain_nonisbits(dt))"""))
+             end
         end
     end

Sadly, this approach is insufficient for more complex cases such as:

struct Bar{T}
    a::T
    b::T
end

function main2()
    foo(f) = (args...) -> f(args...)

    a = cu(zeros(5)); b = cu(ones(5)); c = Bar{Float32}; d = foo(c)
    foo(c).(a, b)
end

Validation rejects this with the following error; the generated code is shown after it:

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, var"#3#5"{Type{Bar{Float32}}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .f is of type var"#3#5"{Type{Bar{Float32}}} which is not isbits.
    .f is of type Type{Bar{Float32}} which is not isbits.
define ptx_kernel void @_Z3_3415CuKernelContext13CuDeviceArrayI15BrokenBroadcastI3AnyELi1ELi1EE11BroadcastedI12CuArrayStyleILi1E12DeviceMemoryE5TupleI5OneToI5Int64EE2_3I4TypeI3BarI7Float32EEES9_I8ExtrudedIS0_ISH_Li1ELi1EES9_I4BoolES9_ISB_EESQ_EESB_({ i64, i32 } %state, { i8 addrspace(1)*, i64, [1 x i64], i64 } %0, { [1 x {}*], [2 x { { i8 addrspace(1)*, i64, [1 x i64], i64 }, [1 x i8], [1 x i64] }], [1 x [1 x i64]] } %1, i64 signext %2) local_unnamed_addr {
conversion:
  %.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [1 x i64], i64 } %0, 3
  %.not7 = icmp slt i64 %2, 1
  br i1 %.not7, label %common.ret, label %L5.lr.ph

L5.lr.ph:                                         ; preds = %conversion
  %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %4 = add nuw nsw i32 %3, 1
  %5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %6 = zext i32 %5 to i64
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %8 = zext i32 %7 to i64
  %9 = mul nuw nsw i64 %6, %8
  %10 = zext i32 %4 to i64
  %11 = add nuw nsw i64 %9, %10
  %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
  %13 = mul i32 %12, %7
  %14 = sext i32 %13 to i64
  br label %L5

L5:                                               ; preds = %L5, %L5.lr.ph
  %value_phi8 = phi i64 [ 1, %L5.lr.ph ], [ %20, %L5 ]
  %15 = add i64 %value_phi8, -1
  %16 = mul i64 %15, %14
  %17 = add i64 %11, %16
  %18 = icmp slt i64 %17, 1
  %19 = icmp sgt i64 %17, %.fca.3.extract
  %spec.select = select i1 %18, i1 true, i1 %19
  %20 = add i64 %value_phi8, 1
  %.not = icmp sgt i64 %20, %2
  %or.cond = select i1 %spec.select, i1 true, i1 %.not
  br i1 %or.cond, label %common.ret, label %L5

common.ret:                                       ; preds = %L5, %conversion
  ret void
}

Note that the non-isbits Broadcasted argument is used, so it also fails the more lenient validation check, even though the part of it being used is not the managed pointer.


I'm not sure how to proceed with this. Simply removing the validation and trusting that other aspects of IR validation will error seems too optimistic: IIRC we introduced this check to prevent accidentally reading CPU memory from the GPU. And actually detecting whether the managed pointer field is the one being used seems hard.

I'm also not sure how important that is; we've not received many bug reports about this, and the motivating example by @BioTurboNick would simply fail after validation anyway because it involves a broken broadcast (producing Any values). So maybe this isn't very important.
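
As a rough sketch of the type-level half of that detection problem only, the fields that would be passed as managed pointers can be enumerated by walking the argument type, much like explain_nonisbits does. managed_fields, Inner and Outer are hypothetical names, not part of GPUCompiler, and mapping these paths onto uses in the LLVM IR remains the hard part.

```julia
# Hypothetical helper: list the field paths of `dt` whose values are not stored
# inline, i.e. would be passed to the kernel as managed (host) pointers.
function managed_fields(@nospecialize(dt), path::String="arg", out=String[])
    isbitstype(dt) && return out
    # Non-concrete types (including Type{...}), mutable types and field-less
    # types are treated as leaves: the value itself is a managed pointer.
    if !isconcretetype(dt) || ismutabletype(dt) || fieldcount(dt) == 0
        push!(out, path)
        return out
    end
    # Immutable struct with a non-isbits field somewhere inside: recurse.
    for i in 1:fieldcount(dt)
        managed_fields(fieldtype(dt, i), string(path, ".", fieldname(dt, i)), out)
    end
    return out
end

# Example types, only for illustration.
struct Inner{F}
    f::F
end
struct Outer{F,N}
    f::Inner{F}
    n::N
end

# Only the captured type is a managed pointer; the Int field is plain data.
@show managed_fields(Outer{Type{Float32},Int})
```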
