Skip to content

Commit

Permalink
Refactor codegen methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed May 18, 2024
1 parent cd5f116 commit 611c642
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ function codegen(::Val{:inplace}, path::EinExpr)

args = map(Iterators.flatten([Leaves(path), Branches(path)])) do node
i = ssa[node]
N = ndims(node)
:($(Symbol(:ssa, i))) #::Tensor{$T,$N})
:($(Symbol(:ssa, i)))
end

ssa_eincodes = map(Branches(path)) do branch
Expand All @@ -34,11 +33,10 @@ function codegen(::Val{:inplace}, path::EinExpr)

ssa[branch] = k

# WARN hardcoded return type
return :(contract!($ssa_c, $ssa_a, $ssa_b))
end

:(function $(gensym(:contract_compiled))($(args...)) # FIX this is hardcoded
:(function $(gensym(:contract_compiled))($(args...))
$(ssa_eincodes...)
return $(Symbol(:ssa, ssa[path]))
end)
Expand Down Expand Up @@ -69,7 +67,7 @@ function codegen(::Val{:outplace}, path::EinExpr)
return :($ssa_c = contract($ssa_a, $ssa_b))
end

:(function $(gensym(:contract_compiled))($(args...)) # FIX this is hardcoded
:(function $(gensym(:contract_compiled))($(args...))
$(ssa_eincodes...)
return $(Symbol(:ssa, ssa[path]))
end)
Expand Down

0 comments on commit 611c642

Please sign in to comment.