Skip to content

Commit

Permalink
Make GPUInterpreter extensible
Browse files Browse the repository at this point in the history
Currently Enzyme uses it's own AbstractInterpreter, in particular to
handle inlining blocking of functions with custom rules and to handle
nested autodiff operations.

- [ ] Create a version of Enzyme with this
- [ ] Support a version of `gpuc.deferred(meta)`
  • Loading branch information
vchuravy committed Sep 26, 2024
1 parent dfd5c35 commit 64e5353
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
# provide a specific interpreter to use.
if VERSION >= v"1.11.0-DEV.1552"
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
token=ci_cache_token(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
else
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
code_cache=ci_cache(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
end
Expand Down
132 changes: 124 additions & 8 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ else
end

struct GPUInterpreter <: CC.AbstractInterpreter
meta::Any
world::UInt
method_table::GPUMethodTableView

Expand All @@ -336,6 +337,7 @@ end

@static if HAS_INTEGRATED_CACHE
function GPUInterpreter(world::UInt=Base.get_world_counter();
meta = nothing,
method_table::MTType,
token::Any,
inf_params::CC.InferenceParams,
Expand All @@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
token, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
meta=interp.meta,
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
token::Any=interp.token,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
token, inf_cache,
inf_params, opt_params)
end

else

function GPUInterpreter(world::UInt=Base.get_world_counter();
meta=nothing,
method_table::MTType,
code_cache::CodeCache,
inf_params::CC.InferenceParams,
Expand All @@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
code_cache, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
meta=interp.meta,
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
code_cache::CodeCache=interp.code_cache,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
code_cache, inf_cache,
inf_params, opt_params)
end
Expand Down Expand Up @@ -445,9 +450,11 @@ struct DeferredCallInfo <: CC.CallInfo
end

# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
# default implementation, extensible through meta argument.
# XXX: (or should we dispatch on `f`)?
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
(; fargs, argtypes) = arginfo
if f === var"gpuc.deferred"
argvec = argtypes[2:end]
Expand All @@ -459,6 +466,20 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
end
end
return nothing
end

function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
if candidate === nothing && interp.meta !== nothing
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
end
if candidate !== nothing
return candidate
end

return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
Expand Down Expand Up @@ -542,6 +563,101 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
end
end

import .CC: CallInfo
struct NoInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
kind::Symbol
NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) =
new(info, tt, kind)
end

CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
struct AlwaysInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt)
end

CC.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info)
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)


function inlining_handler(meta::Nothing, interp::GPUCompiler, @nospecialize(atype), callinfo)
return nothing
end

using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function CC.abstract_call_gf_by_type(interp::GPUCompiler, @nospecialize(f), arginfo::ArgInfo,
si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int)
ret = @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, arginfo::ArgInfo,
si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int)

callinfo = nothing
if interp.meta !== nothing
callinfo = inlining_handler(interp.meta, interp, atype, ret.callinfo)
end
if callinfo === nothing
callinfo = inlining_handler(nothing, interp, atype, ret.callinfo)
end
if callinfo === nothing
callinfo = ret.callinfo
end

@static if VERSION v"1.11-"
return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
else
return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo)
end
end

let # overload `inlining_policy`
@static if VERSION v"1.11.0-DEV.879"
sigs_ex = :(
interp::GPUInterpreter,
@nospecialize(src),
@nospecialize(info::Core.Compiler.CallInfo),
stmt_flag::UInt32,
)
args_ex = :(
interp::AbstractInterpreter,
src::Any,
info::Core.Compiler.CallInfo,
stmt_flag::UInt32,
)
else
sigs_ex = :(
interp::GPUInterpreter,
@nospecialize(src),
@nospecialize(info::Core.Compiler.CallInfo),
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
args_ex = :(
interp::AbstractInterpreter,
src::Any,
info::Core.Compiler.CallInfo,
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
end
@eval function Core.Compiler.inlining_policy($(sigs_ex.args...))
if info isa NoInlineCallInfo
@safe_debug "Blocking inlining" info.tt info.kind
return nothing
elseif info isa AlwaysInlineCallInfo
@safe_debug "Forcing inlining for" info.tt
return src
end
return @invoke Core.Compiler.inlining_policy($(args_ex.args...))
end
end


## world view of the cache
using Core.Compiler: WorldView
Expand Down Expand Up @@ -704,7 +820,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
source = pop!(worklist)
haskey(compiled, source) && continue # We have fulfilled the request already
# Create a new compiler job for this edge, reusing the config settings from the inital one
job2 = CompilerJob(source, job.config)
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
append!(worklist, outstanding) # merge worklist with new outstanding edges
@assert context(llvm_mod) == context(llvm_mod2)
Expand Down

0 comments on commit 64e5353

Please sign in to comment.