Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean code #132

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ end

function compile_mlir(f, args; kwargs...)
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
append!(Reactant.MLIR.registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
Expand Down Expand Up @@ -567,7 +567,7 @@ end
function compile_xla(f, args; client=nothing)
# register MLIR dialects
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
append!(Reactant.MLIR.registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
Expand Down
209 changes: 209 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
using Enzyme

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
if haskey(seen, prev)
return seen[prev]
end
if Enzyme.Compiler.guaranteed_const_nongen(RT, nothing)
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
if RT <: ConcreteRArray
res = RT(zeros(eltype(RT), size(prev)))
seen[prev] = res
return res
end

if RT <: TracedRArray
res = broadcast_to_size(eltype(RT)(0), size(prev))
seen[prev] = res
return res
end

attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev))
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
res = RT((), cst)
seen[prev] = res
return res
end

@reactant_override function Enzyme.autodiff(
::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs}
) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs}
reverse = CMode <: Enzyme.ReverseMode

primf = f.val
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
)

activity = Int32[]
ad_inputs = MLIR.IR.Value[]

for a in linear_args
idx, path = get_argidx(a)
if idx == 1 && fnwrap
push!(activity, act_from_type(f, reverse))
push_acts!(ad_inputs, f, path[3:end], reverse)
else
if fnwrap
idx -= 1
end
push!(activity, act_from_type(args[idx], reverse))
push_acts!(ad_inputs, args[idx], path[3:end], reverse)
end
end

outtys = MLIR.IR.Type[]
@inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} =
ReturnPrimal
for a in linear_results
if has_residx(a)
if needs_primal(CMode)
push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data)))
end
else
push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data)))
end
end
for (i, act) in enumerate(activity)
if act == enzyme_out || (reverse && (act == enzyme_dup || act == enzyme_dupnoneed))
push!(outtys, in_tys[i])# transpose_ty(MLIR.IR.type(MLIR.IR.operand(ret, i))))
end
end

ret_activity = Int32[]
for a in linear_results
if has_residx(a)
act = act_from_type(A, reverse, needs_primal(CMode))
push!(ret_activity, act)
if act == enzyme_out || act == enzyme_outnoneed
attr = fill(MLIR.IR.Attribute(eltype(a)(1)), mlir_type(a))
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
push!(ad_inputs, cst)
end
else
idx, path = get_argidx(a)
if idx == 1 && fnwrap
act = act_from_type(f, reverse, true)
push!(ret_activity, act)
if act != enzyme_out && act != enzyme_outnoneed
continue
end
push_val!(ad_inputs, f.dval, path[3:end])
else
if fnwrap
idx -= 1
end
act = act_from_type(args[idx], reverse, true)
push!(ret_activity, act)
if act != enzyme_out && act != enzyme_outnoneed
continue
end
push_val!(ad_inputs, args[idx].dval, path[3:end])
end
end
end

function act_attr(val)
val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet(
MLIR.IR.context()::MLIR.API.MlirContext, val::Int32
)::MLIR.API.MlirAttribute
return MLIR.IR.Attribute(val)
end
fname = get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)(
[transpose_val(v) for v in ad_inputs];
outputs=outtys,
fn=fname,
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
)

residx = 1

for a in linear_results
if has_residx(a)
if needs_primal(CMode)
path = get_residx(a)
set!(result, path[2:end], transpose_val(MLIR.IR.result(res, residx)))
residx += 1
end
else
idx, path = get_argidx(a)
if idx == 1 && fnwrap
set!(f.val, path[3:end], transpose_val(MLIR.IR.result(res, residx)))
residx += 1
else
if fnwrap
idx -= 1
end
set!(args[idx].val, path[3:end], transpose_val(MLIR.IR.result(res, residx)))
residx += 1
end
end
end

restup = Any[(a isa Active) ? copy(a) : nothing for a in args]
for a in linear_args
idx, path = get_argidx(a)
if idx == 1 && fnwrap
if act_from_type(f, reverse) != enzyme_out
continue
end
if f isa Enzyme.Active
@assert false
residx += 1
continue
end
set_act!(f, path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)))
else
if fnwrap
idx -= 1
end
if act_from_type(args[idx], reverse) != enzyme_out
continue
end
if args[idx] isa Enzyme.Active
set_act!(
args[idx],
path[3:end],
false,
transpose_val(MLIR.IR.result(res, residx));
emptypaths=true,
) #=reverse=#
residx += 1
continue
end
set_act!(
args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx))
)
end
residx += 1
end

func2.operation = MLIR.API.MlirOperation(C_NULL)

if reverse
resv = if needs_primal(CMode)
result
else
nothing
end
return ((restup...,), resv)
else
if A <: Const
return result
else
dres = copy(result)
throw(AssertionError("TODO implement forward mode handler"))
if A <: Duplicated
return ()
end
end
end
end
Loading
Loading