Skip to content

Commit

Permalink
Make GPUInterpreter extensible
Browse files Browse the repository at this point in the history
Currently Enzyme uses it's own AbstractInterpreter, in particular to
handle inlining blocking of functions with custom rules and to handle
nested autodiff operations.

- [ ] Create a version of Enzyme with this
- [ ] Support a version of `gpuc.deferred(meta)`
  • Loading branch information
vchuravy committed Sep 27, 2024
1 parent e9d1372 commit 6f8532b
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
# provide a specific interpreter to use.
if VERSION >= v"1.11.0-DEV.1552"
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
token=ci_cache_token(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
else
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; meta=nothing, method_table=method_table(job),
code_cache=ci_cache(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
end
Expand Down
162 changes: 154 additions & 8 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ else
end

struct GPUInterpreter <: CC.AbstractInterpreter
meta::Any
world::UInt
method_table::GPUMethodTableView

Expand All @@ -336,6 +337,7 @@ end

@static if HAS_INTEGRATED_CACHE
function GPUInterpreter(world::UInt=Base.get_world_counter();
meta = nothing,
method_table::MTType,
token::Any,
inf_params::CC.InferenceParams,
Expand All @@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
token, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
meta=interp.meta,
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
token::Any=interp.token,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
token, inf_cache,
inf_params, opt_params)
end

else

function GPUInterpreter(world::UInt=Base.get_world_counter();
meta=nothing,
method_table::MTType,
code_cache::CodeCache,
inf_params::CC.InferenceParams,
Expand All @@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
code_cache, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
meta=interp.meta,
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
code_cache::CodeCache=interp.code_cache,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(meta, world, method_table,
code_cache, inf_cache,
inf_params, opt_params)
end
Expand Down Expand Up @@ -437,6 +442,8 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
end


within_gpucompiler() = false

## deferred compilation

struct DeferredCallInfo <: CC.CallInfo
Expand All @@ -445,9 +452,11 @@ struct DeferredCallInfo <: CC.CallInfo
end

# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
# default implementation, extensible through meta argument.
# XXX: (or should we dispatch on `f`)?
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
(; fargs, argtypes) = arginfo
if f === var"gpuc.deferred"
argvec = argtypes[2:end]
Expand All @@ -458,7 +467,34 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
else
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
end
elseif f === within_gpucompiler
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
else
return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
else
return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),)
end
end
return nothing
end

function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
if candidate === nothing && interp.meta !== nothing
candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
end
if candidate !== nothing
return candidate
end

return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
Expand Down Expand Up @@ -542,6 +578,116 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
end
end

import .CC: CallInfo
struct NoInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
kind::Symbol
NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) =
new(info, tt, kind)
end

CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
struct AlwaysInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt)
end

CC.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info)
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)


function inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo)
return nothing
end

using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo,
si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int)
ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any, arginfo::ArgInfo,
si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int)

callinfo = nothing
if interp.meta !== nothing
callinfo = inlining_handler(interp.meta, interp, atype, ret.info)
end
if callinfo === nothing
callinfo = inlining_handler(nothing, interp, atype, ret.info)
end
if callinfo === nothing
callinfo = ret.info
end

@static if VERSION v"1.11-"
return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
else
return CC.CallMeta(ret.rt, ret.effects, callinfo)
end
end

@static if VERSION < v"1.12.0-DEV.45"
let # overload `inlining_policy`
@static if VERSION v"1.11.0-DEV.879"
sigs_ex = :(
interp::GPUInterpreter,
@nospecialize(src),
@nospecialize(info::CC.CallInfo),
stmt_flag::UInt32,
)
args_ex = :(
interp::CC.AbstractInterpreter,
src::Any,
info::CC.CallInfo,
stmt_flag::UInt32,
)
else
sigs_ex = :(
interp::GPUInterpreter,
@nospecialize(src),
@nospecialize(info::CC.CallInfo),
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
args_ex = :(
interp::CC.AbstractInterpreter,
src::Any,
info::CC.CallInfo,
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
end
@eval function CC.inlining_policy($(sigs_ex.args...))
if info isa NoInlineCallInfo
@safe_debug "Blocking inlining" info.tt info.kind
return nothing
elseif info isa AlwaysInlineCallInfo
@safe_debug "Forcing inlining for" info.tt
return src
end
return @invoke CC.inlining_policy($(args_ex.args...))
end
end
else
function CC.src_inlining_policy(interp::GPUInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt32)

if info isa NoInlineCallInfo
@safe_debug "Blocking inlining" info.tt info.kind
return false
elseif info isa AlwaysInlineCallInfo
@safe_debug "Forcing inlining for" info.tt
return true
end
return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src, info::CC.CallInfo, stmt_flag::UInt32)
end
end


## world view of the cache
using Core.Compiler: WorldView
Expand Down Expand Up @@ -704,7 +850,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
source = pop!(worklist)
haskey(compiled, source) && continue # We have fulfilled the request already
# Create a new compiler job for this edge, reusing the config settings from the inital one
job2 = CompilerJob(source, job.config)
job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config?
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
append!(worklist, outstanding) # merge worklist with new outstanding edges
@assert context(llvm_mod) == context(llvm_mod2)
Expand Down
8 changes: 8 additions & 0 deletions test/ptx_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ end
@test !occursin("gpucompiler.mark", ir)
end

@testset "within_gpucompiler" begin
function kernel(a)
unsafe_store!(a, GPUCompiler.within_gpucompiler())
end
ir = sprint(io->code_llvm(io, kernel, Tuple{Int}))
@show ir
end

@testset "exception arguments" begin
function kernel(a)
unsafe_store!(a, trunc(Int, unsafe_load(a)))
Expand Down

0 comments on commit 6f8532b

Please sign in to comment.