
fails to differentiate integers powers of Metal array #1533

Open
CarloLucibello opened this issue Oct 17, 2024 · 1 comment

CarloLucibello (Member) commented Oct 17, 2024

I see the following error with this simple example involving Metal arrays:

julia> using Metal, Zygote

julia> x = Metal.ones(2)

julia> gradient(x -> sum(x.^2), x)
ERROR: InvalidIRError: compiling MethodInstance for (::Metal.var"#broadcast_cartesian_static#213")(::MtlDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Metal.StaticCartesianIndices{…}) resulted in invalid LLVM IR
Reason: unsupported use of double value
Stacktrace:
 [1] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] #power_by_squaring#526
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] #power_by_squaring#526
   @ ./intfuncs.jl:0
 [2] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] *
   @ ./float.jl:493
 [2] #power_by_squaring#526
   @ ./intfuncs.jl:320
 [3] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
Stacktrace:
 [1] ==
   @ ./float.jl:616
 [2] isone
   @ ./number.jl:62
 [3] #power_by_squaring#526
   @ ./intfuncs.jl:322
 [4] multiple call sites
   @ unknown:0
Reason: unsupported use of double value
.... repeated multiple times .......
Stacktrace:
  [1] Float32
    @ ./float.jl:338
  [2] ^
    @ ./math.jl:1231
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:673
  [4] _broadcast_getindex
    @ ./broadcast.jl:646
  [5] _getindex
    @ ./broadcast.jl:670
  [6] _broadcast_getindex
    @ ./broadcast.jl:645
  [7] _getindex
    @ ./broadcast.jl:670
  [8] _getindex
    @ ./broadcast.jl:669
  [9] _broadcast_getindex
    @ ./broadcast.jl:645
 [10] getindex
    @ ./broadcast.jl:605
 [11] broadcast_cartesian_static
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:67
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/validation.jl:147
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:382 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/NRdsv/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:381 [inlined]
  [5] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:108
  [6] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:100
  [7] codegen
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:82 [inlined]
  [8] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:79
  [9] compile
    @ ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:74 [inlined]
 [10] (::Metal.var"#154#162"{GPUCompiler.CompilerJob{}})(ctx::LLVM.Context)
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:108
 [11] JuliaContext(f::Metal.var"#154#162"{GPUCompiler.CompilerJob{}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:34
 [12] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/driver.jl:25
 [13] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:107 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/ObjectiveC/C7BVt/src/os.jl:264 [inlined]
 [15] compile(job::GPUCompiler.CompilerJob)
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/compilation.jl:105
 [16] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/execution.jl:237
 [17] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2CW9L/src/execution.jl:151
 [18] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:189 [inlined]
 [19] macro expansion
    @ ./lock.jl:273 [inlined]
 [20] mtlfunction(f::Metal.var"#broadcast_cartesian_static#213", tt::Type{Tuple{…}}; name::Nothing, kwargs::@Kwargs{})
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:184
 [21] mtlfunction(f::Metal.var"#broadcast_cartesian_static#213", tt::Type{Tuple{…}})
    @ Metal ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:182
 [22] macro expansion
    @ ~/.julia/packages/Metal/rBb1i/src/compiler/execution.jl:85 [inlined]
 [23] _copyto!
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:74 [inlined]
 [24] copyto!
    @ ~/.julia/packages/Metal/rBb1i/src/broadcast.jl:47 [inlined]
 [25] copy
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [26] materialize
    @ ./broadcast.jl:867 [inlined]
 [27] (::Zygote.var"#1257#1260"{2, MtlVector{Float32, Metal.PrivateStorage}})(ȳ::MtlVector{Float32, Metal.PrivateStorage})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/broadcast.jl:108
 [28] #3916#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [29] #631
    @ ./REPL[9]:1 [inlined]
 [30] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:91
 [31] gradient(f::Function, args::MtlVector{Float32, Metal.PrivateStorage})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:148
 [32] top-level scope
    @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

I get a similar error for x.^3, but not for x.^2f0. cc @maleadt
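One plausible reason the two spellings take different paths: Julia lowers an integer-literal exponent such as `x .^ 2` into a broadcast of `Base.literal_pow` with a `Val(2)` argument, while `x .^ 2f0` stays a plain broadcast of `^`. Only the `literal_pow` form matches the Zygote adjoint quoted in mcabbott's reply. The snippet below is a CPU-only sketch (no Metal or Zygote needed) that inspects the lowered code; the exact printed text varies between Julia versions, so it is only worth looking for the `literal_pow` name.

```julia
# Lowering only rewrites syntax, so `x` does not need to be defined here.
int_form   = Meta.@lower x .^ 2     # integer literal exponent
float_form = Meta.@lower x .^ 2f0   # Float32 literal exponent

# Expect `Base.literal_pow` (with a Val{2} argument) in the first form only.
println(int_form)
println(float_form)
```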

mcabbott (Member) commented:

I think x.^2 and x.^3 should hit this adjoint rule. Is x .^ (p - 1) somehow producing Float64?

```julia
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
  y = Base.literal_pow.(^, x, exp)
  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end
```
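The trace points at the pullback rather than the forward pass: frame [27] is Zygote's `lib/broadcast.jl:108` closure, and the failing frames include `power_by_squaring` plus a `Float32` conversion (`float.jl:338`) called from `^` (`math.jl:1231`). That suggests, as an assumption not confirmed in this thread, that `x .^ (p - 1)` with a runtime `Int` exponent goes through Base's `^(::Float32, ::Integer)`, which computes in `Float64` before converting back. Below is a minimal CPU sketch of one possible workaround: compute the integer power with multiplications in the element type itself. `intpow` and `literal_pow_pullback` are hypothetical names, not Zygote or Metal API, and whether this exact code compiles for Metal has not been checked here.

```julia
# Hypothetical helper: x^n by repeated squaring, using only multiplication
# (and `inv` for negative n) in x's own float type T, so a Float32 input
# is never widened to Float64.
function intpow(x::T, n::Integer) where {T<:AbstractFloat}
    base = n < 0 ? inv(x) : x
    m = abs(n)
    result = one(T)
    while m > 0
        if isodd(m)
            result *= base
        end
        base *= base
        m >>= 1
    end
    return result
end

# Hypothetical pullback mirroring `ȳ .* p .* conj.(x .^ (p - 1))` from the
# adjoint above, with the integer power swapped for `intpow`.
literal_pow_pullback(x, ȳ, p::Integer) = ȳ .* p .* conj.(intpow.(x, p - 1))

x  = Float32[1, 2, 3]
ȳ  = ones(Float32, 3)
dx = literal_pow_pullback(x, ȳ, 2)   # gradient of sum(x.^2) is 2x
println(dx, " ", eltype(dx))
```

The squaring loop keeps the number of multiplications logarithmic in `n`, like Base's `power_by_squaring`, but stays in `T` throughout; for the common `p = 2` and `p = 3` cases, writing the product out directly (`x .* x`) would work just as well.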
