diff --git a/src/Compiler.jl b/src/Compiler.jl index 9917d0291..031446a33 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -16,8 +16,7 @@ function codegen(::Val{:inplace}, path::EinExpr) args = map(Iterators.flatten([Leaves(path), Branches(path)])) do node i = ssa[node] - N = ndims(node) - :($(Symbol(:ssa, i))) #::Tensor{$T,$N}) + :($(Symbol(:ssa, i))) end ssa_eincodes = map(Branches(path)) do branch @@ -34,11 +33,10 @@ function codegen(::Val{:inplace}, path::EinExpr) ssa[branch] = k - # WARN hardcoded return type return :(contract!($ssa_c, $ssa_a, $ssa_b)) end - :(function $(gensym(:contract_compiled))($(args...)) # FIX this is hardcoded + :(function $(gensym(:contract_compiled))($(args...)) $(ssa_eincodes...) return $(Symbol(:ssa, ssa[path])) end) @@ -69,7 +67,7 @@ function codegen(::Val{:outplace}, path::EinExpr) return :($ssa_c = contract($ssa_a, $ssa_b)) end - :(function $(gensym(:contract_compiled))($(args...)) # FIX this is hardcoded + :(function $(gensym(:contract_compiled))($(args...)) $(ssa_eincodes...) return $(Symbol(:ssa, ssa[path])) end)