Turing guaranteed to error (when compiled with GPUCompiler) #2240

Open
wsmoses opened this issue May 28, 2024 · 7 comments

Comments

@wsmoses
Collaborator

wsmoses commented May 28, 2024

julia> Core.Compiler.typeinf_ext_toplevel(interp, mi)
CodeInfo(
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:93 within `logdensity`
    ┌ @ Base.jl:37 within `getproperty`
1 ──│ %1  = Base.getfield(f, :varinfo)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   └
│   ┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/abstract_varinfo.jl:747 within `unflatten` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:134 @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:137
│   │┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:116 within `VarInfo`
│   ││┌ @ Base.jl:37 within `getproperty`
│   │││ %2  = Base.getfield(%1, :metadata)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:158 within `newmetadata`
│   │││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl within `macro expansion`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %3  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %4  = Base.getfield(%3, :idcs)::Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}
│   │││││ %5  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %6  = Base.getfield(%5, :vns)::Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}
│   │││││ %7  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %8  = Base.getfield(%7, :ranges)::Vector{UnitRange{Int64}}
│   │││││ %9  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %10 = Base.getfield(%9, :ranges)::Vector{UnitRange{Int64}}
│   ││││└
│   ││││ %11 = DynamicPPL.length::typeof(length)
│   ││││┌ @ reducedim.jl:1011 within `sum`
│   │││││┌ @ reducedim.jl:1011 within `#sum#829`
│   ││││││┌ @ reducedim.jl:1015 within `_sum`
│   │││││││┌ @ reducedim.jl:1015 within `#_sum#831`
│   ││││││││ %12 = Base.add_sum::typeof(Base.add_sum)
│   ││││││││┌ @ reducedim.jl:357 within `mapreduce`
│   │││││││││┌ @ reducedim.jl:357 within `#mapreduce#821`
│   ││││││││││┌ @ reducedim.jl:365 within `_mapreduce_dim`
│   │││││││││││ %13 = invoke Base._mapreduce(%11::typeof(length), %12::typeof(Base.add_sum), $(QuoteNode(IndexLinear()))::IndexLinear, %10::Vector{UnitRange{Int64}})::Int64
│   ││││└└└└└└└
│   ││││┌ @ int.jl:87 within `+`
│   │││││ %14 = Base.add_int(0, %13)::Int64
│   ││││└
│   ││││┌ @ range.jl:5 within `Colon`
│   │││││┌ @ range.jl:403 within `UnitRange`
│   ││││││┌ @ range.jl:414 within `unitrange_last`
│   │││││││┌ @ operators.jl:425 within `>=`
│   ││││││││┌ @ int.jl:514 within `<=`
│   │││││││││ %15 = Base.sle_int(1, %14)::Bool
│   │││││││└└
└───│││││││       goto #3 if not %15
2 ──│││││││       goto #4
3 ──│││││││       goto #4
    ││││││└
4 ┄─││││││ %19 = φ (#2 => %14, #3 => 0)::Int64
│   ││││││ %20 = %new(UnitRange{Int64}, 1, %19)::UnitRange{Int64}
└───││││││       goto #5
5 ──││││││       goto #6
    ││││└└
    ││││┌ @ array.jl:973 within `getindex`
6 ──│││││       goto #11 if not true
    │││││┌ @ abstractarray.jl:700 within `checkbounds`
7 ──││││││ %24 = Core.tuple(%20)::Tuple{UnitRange{Int64}}
│   ││││││ @ abstractarray.jl:702 within `checkbounds` @ abstractarray.jl:687
│   ││││││┌ @ abstractarray.jl:389 within `eachindex`
│   │││││││┌ @ abstractarray.jl:137 within `axes1`
│   ││││││││┌ @ abstractarray.jl:98 within `axes`
│   │││││││││┌ @ array.jl:191 within `size`
│   ││││││││││ %25 = Base.arraysize(θ, 1)::Int64
│   │││││││││└
│   │││││││││┌ @ tuple.jl:291 within `map`
│   ││││││││││┌ @ range.jl:469 within `oneto`
│   │││││││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││││││┌ @ int.jl:83 within `<`
│   ││││││││││││││ %26 = Base.slt_int(%25, 0)::Bool
│   │││││││││││││└
│   │││││││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││││││ %27 = Core.ifelse(%26, 0, %25)::Int64
│   ││││││└└└└└└└└
│   ││││││┌ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ range.jl:672 within `isempty`
│   ││││││││┌ @ operators.jl:378 within `>`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %28 = Base.slt_int(%19, 1)::Bool
│   │││││││└└└
│   │││││││ @ abstractarray.jl:768 within `checkindex` @ abstractarray.jl:763
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %29 = Base.sub_int(1, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %30 = Base.bitcast(UInt64, %29)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %31 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %32 = Base.ult_int(%30, %31)::Bool
│   │││││││└
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %33 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %34 = Base.bitcast(UInt64, %33)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %35 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %36 = Base.ult_int(%34, %35)::Bool
│   │││││││└
│   │││││││ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ bool.jl:38 within `&`
│   ││││││││ %37 = Base.and_int(%32, %36)::Bool
│   │││││││└
│   │││││││┌ @ bool.jl:39 within `|`
│   ││││││││ %38 = Base.or_int(%28, %37)::Bool
│   ││││││└└
│   ││││││ @ abstractarray.jl:702 within `checkbounds`
└───││││││       goto #9 if not %38
8 ──││││││       goto #10
9 ──││││││       invoke Base.throw_boundserror::Vector{Float64}, %24::Tuple{UnitRange{Int64}})::Union{}
└───││││││       unreachable
10 ─││││││       nothing::Nothing
    │││││└
    │││││ @ array.jl:974 within `getindex`
    │││││┌ @ range.jl:761 within `length`
    ││││││┌ @ int.jl:86 within `-`
11 ┄│││││││ %44 = Base.sub_int(%19, 1)::Int64
│   ││││││└
│   ││││││┌ @ int.jl:87 within `+`
│   │││││││ %45 = Base.add_int(1, %44)::Int64
│   │││││└└
│   │││││ @ array.jl:975 within `getindex`
│   │││││┌ @ range.jl:706 within `axes`
│   ││││││┌ @ range.jl:761 within `length`
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %46 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ int.jl:87 within `+`
│   ││││││││ %47 = Base.add_int(1, %46)::Int64
│   ││││││└└
│   ││││││┌ @ range.jl:469 within `oneto`
│   │││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %48 = Base.slt_int(%47, 0)::Bool
│   │││││││││└
│   │││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││ %49 = Core.ifelse(%48, 0, %47)::Int64
│   │││││└└└└└
│   │││││┌ @ abstractarray.jl:831 within `similar` @ array.jl:420
│   ││││││┌ @ boot.jl:486 within `Array` @ boot.jl:477
│   │││││││ %50 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Float64}, svec(Any, Int64), 0, :(:ccall), Vector{Float64}, :(%49), :(%49)))::Vector{Float64}
│   │││││└└
│   │││││ @ array.jl:976 within `getindex`
│   │││││┌ @ operators.jl:378 within `>`
│   ││││││┌ @ int.jl:83 within `<`
│   │││││││ %51 = Base.slt_int(0, %45)::Bool
│   │││││└└
└───│││││       goto #13 if not %51
    │││││ @ array.jl:977 within `getindex`
    │││││┌ @ array.jl:368 within `copyto!`
12 ─││││││       invoke Base._copyto_impl!(%50::Vector{Float64}, 1::Int64, θ::Vector{Float64}, 1::Int64, %45::Int64)::Vector{Float64}
    │││││└
    │││││ @ array.jl:979 within `getindex`
13 ┄│││││       goto #14
    ││││└
    ││││┌ @ Base.jl:37 within `getproperty`
14 ─│││││ %55 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %56 = Base.getfield(%55, :dists)::Vector{IsoNormal}
│   │││││ %57 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %58 = Base.getfield(%57, :gids)::Vector{Set{DynamicPPL.Selector}}
│   │││││ %59 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %60 = Base.getfield(%59, :orders)::Vector{Int64}
│   │││││ %61 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %62 = Base.getfield(%61, :flags)::Dict{String, BitVector}
│   ││││└
│   ││││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:47 within `Metadata`
│   │││││ %63 = %new(DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, %4, %6, %8, %50, %56, %58, %60, %62)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   ││││└
│   ││││┌ @ boot.jl:622 within `NamedTuple`
│   │││││ %64 = %new(@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, %63)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││││└
└───││││       goto #15
    ││└└
    ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
    ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:906 within `getlogp`
    │││┌ @ Base.jl:37 within `getproperty`
15 ─││││ %66 = Base.getfield(%1, :logp)::Base.RefValue{Float64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %67 = Base.getfield(%66, :x)::Float64
│   ││└└└
│   ││┌ @ refvalue.jl:8 within `RefValue`
│   │││ %68 = %new(Base.RefValue{Float64}, %67)::Base.RefValue{Float64}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:923 within `get_num_produce`
│   │││┌ @ Base.jl:37 within `getproperty`
│   ││││ %69 = Base.getfield(%1, :num_produce)::Base.RefValue{Int64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %70 = Base.getfield(%69, :x)::Int64
│   ││└└└
│   ││┌ @ refpointer.jl:137 within `Ref`
│   │││┌ @ refvalue.jl:10 within `RefValue` @ refvalue.jl:8
│   ││││ %71 = %new(Base.RefValue{Int64}, %70)::Base.RefValue{Int64}
│   ││└└
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:100
│   ││ %72 = %new(TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, %64, %68, %71)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
└───││       goto #16
16 ─││       goto #17
17 ─││       goto #18
18 ─││       goto #19
    └└
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 within `logdensity`
19 ─       invoke DynamicPPL.evaluate!!($(QuoteNode(Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}(demo2, (var"##arg#225" = DynamicPPL.TypeWrap{Matrix{Float64}}(),), NamedTuple(), DefaultContext())))::Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}, %72::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, $(QuoteNode(DefaultContext()))::DefaultContext)::Union{}
└───       unreachable
)

julia> Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype
Union{}
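A return type of Union{} means inference has proved that the call can never return normally: every path ends in a throw, which is why an `unreachable` follows the `evaluate!!` call in the IR above. The package-free sketch below shows how this looks with Julia's default interpreter. `always_throws` and `returns_float` are hypothetical functions used only for illustration.

```julia
# Hypothetical function whose only code path throws.
always_throws(x::Int) = error("bad input: $x")

# Hypothetical function that returns normally, for comparison.
returns_float(x::Int) = 2.0 * x

# Base.return_types runs Julia's native type inference for the given argument
# types and returns one inferred return type per matching method.
rt_throw = only(Base.return_types(always_throws, (Int,)))
rt_float = only(Base.return_types(returns_float, (Int,)))

println(rt_throw)  # Union{}
println(rt_float)  # Float64
```

When the same query is made through GPUCompiler's interpreter, it is the `logdensity` call, not a toy function, that comes back as Union{}.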
@wsmoses
Collaborator Author

wsmoses commented May 28, 2024

Some relevant code:

using Turing, Enzyme, LinearAlgebra, LogDensityProblems

using AbstractPPL
using DynamicPPL
using Accessors

using GPUCompiler
Enzyme.API.runtimeActivity!(true);

@model function demo2(::Type{TV}=Matrix{Float64}) where {TV}
    d = 2
    n = 2
    x = TV(undef, d, n)
    x[:, 1] ~ MvNormal(zeros(d), I)
    for i = 2:n
        x[:, i] ~ MvNormal(x[:, i - 1], I)
    end
end

model = demo2()
ℓ = Turing.LogDensityFunction(model)
θ = ℓ.varinfo[:]

x = θ

@show LogDensityProblems.logdensity(ℓ, x)

Enzyme.autodiff(ReverseWithPrimal, LogDensityProblems.logdensity, Active, Const(ℓ), Enzyme.Duplicated(x, zero(x)))
# Rebuild by hand the compiler job that Enzyme constructs for the autodiff call above
World = Base.get_world_counter()
FA = Const{typeof(LogDensityProblems.logdensity)}
A = Active
width = 1
Mode = Enzyme.API.DEM_ReverseModeCombined
ModifiedBetween = (false, false)
ReturnPrimal = true
ShadowInit = false
ABI = Enzyme.FFIABI
TT = Tuple{Const{LogDensityFunction{DynamicPPL.TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{AbstractPPL.VarName{:x, Accessors.IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, DynamicPPL.Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.DefaultContext}}, Duplicated{Vector{Float64}}}

mi = Enzyme.Compiler.fspec(eltype(FA), TT, World)

target = Enzyme.Compiler.EnzymeTarget()
params = Enzyme.Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, Enzyme.Compiler.remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, Enzyme.Compiler.UnknownTapeType, ABI)
tmp_job    = Enzyme.Compiler.CompilerJob(mi, Enzyme.Compiler.CompilerConfig(target, params; kernel=false), World)

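# Get GPUCompiler's abstract interpreter for this job and run type inference on the method instance
interp = GPUCompiler.get_interpreter(tmp_job)

spec = Core.Compiler.specialize_method(mi.def, mi.specTypes, mi.sparam_vals)
Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals)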

@torfjelde
Member

So is this a Turing.jl issue or a GPUCompiler.jl issue, given that type inference works fine without GPUCompiler?

@wsmoses
Collaborator Author

wsmoses commented May 29, 2024 via email

@torfjelde
Member

Just to clarify a bit here: there's no "bug" per se in Turing.jl. The "bug" is just that there's no constructor for MvNormal with eltype Any, and arguably that's not something we want to support. That is, this model is only expected to work when everything is type-stable.

But this constructor is only hit because GPUCompiler somehow causes an inference issue, producing an eltype of Any where every other approach correctly infers Float64.
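The mechanism described above can be reproduced without Turing. If inference widens an argument's element type to Any, a call to a method that only accepts AbstractVector{<:Real} has no applicable method, so inference concludes the call always throws. In the minimal sketch below, `make_dist` is a hypothetical stand-in for an MvNormal-style constructor restricted to real-valued means. It is not the Distributions.jl API.

```julia
# Hypothetical stand-in for a constructor that only accepts real-valued means.
make_dist(μ::AbstractVector{<:Real}) = sum(μ)

# Caller that forwards whatever vector it receives.
build(v) = make_dist(v)

# With a concretely typed Vector{Float64}, inference finds the method.
rt_typed = only(Base.return_types(build, (Vector{Float64},)))

# With Vector{Any} (what a widened inference result looks like), no method
# matches, so inference concludes the call can only throw.
rt_any = only(Base.return_types(build, (Vector{Any},)))

println(rt_typed)  # Float64
println(rt_any)    # Union{}

# At runtime, the Vector{Any} call does throw a MethodError.
err = try
    build(Any[0.0, 0.0])
    nothing
catch e
    e
end
println(err isa MethodError)  # true
```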

@wsmoses
Collaborator Author

wsmoses commented May 29, 2024

Sure, though I'm not sure which of the subpackages used by Turing causes the error.

My best guess is that somewhere typeof(x) [aka inferred type] is used instead of Core.Typeof(x) [aka runtime type]; using the latter would make the construction correct.
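For reference, typeof and Core.Typeof disagree only when the value is itself a type. That is the situation in demo2, where TV = Matrix{Float64} is passed as an argument. The minimal sketch below, which uses only Base, shows that typeof on a type returns the generic DataType, while Core.Typeof returns the singleton Type{Matrix{Float64}}, which keeps the information. The variable names are illustrative only.

```julia
TV = Matrix{Float64}

# For an ordinary value, both functions agree.
println(typeof(1.0) === Core.Typeof(1.0))  # true

# For a type value, typeof gives the generic DataType, while Core.Typeof gives
# the singleton type Type{Matrix{Float64}}.
println(typeof(TV))       # DataType
println(Core.Typeof(TV))  # Type{Matrix{Float64}}

# A container typed from typeof(TV) can hold any DataType, so downstream code
# only knows "some type". A container typed from Core.Typeof(TV) pins down TV.
loose = Vector{typeof(TV)}()
tight = Vector{Core.Typeof(TV)}()
println(eltype(loose))  # DataType
println(eltype(tight))  # Type{Matrix{Float64}}
```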

@wsmoses
Collaborator Author

wsmoses commented May 29, 2024

cc @maleadt @vchuravy for visibility

@mhauru
Collaborator

mhauru commented Jul 15, 2024

Gathering links to Enzyme issues that came up when trying to minimise this:
EnzymeAD/Enzyme.jl#1608
EnzymeAD/Enzyme.jl#1621
EnzymeAD/Enzyme.jl#1623
EnzymeAD/Enzyme.jl#1626

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants