diff --git a/src/interface.jl b/src/interface.jl index 1346c858..2761c37b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -181,12 +181,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false # provide a specific interpreter to use. if VERSION >= v"1.11.0-DEV.1552" get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table=method_table(job), + GPUInterpreter(job.world; meta=nothing, method_table=method_table(job), token=ci_cache_token(job), inf_params=inference_params(job), opt_params=optimization_params(job)) else get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table=method_table(job), + GPUInterpreter(job.world; meta=nothing, method_table=method_table(job), code_cache=ci_cache(job), inf_params=inference_params(job), opt_params=optimization_params(job)) end diff --git a/src/jlgen.jl b/src/jlgen.jl index c6be8c94..0e298b6a 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -320,6 +320,7 @@ else end struct GPUInterpreter <: CC.AbstractInterpreter + meta::Any world::UInt method_table::GPUMethodTableView @@ -336,6 +337,7 @@ end @static if HAS_INTEGRATED_CACHE function GPUInterpreter(world::UInt=Base.get_world_counter(); + meta = nothing, method_table::MTType, token::Any, inf_params::CC.InferenceParams, @@ -345,19 +347,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter(); method_table = get_method_table_view(world, method_table) inf_cache = Vector{CC.InferenceResult}() - return GPUInterpreter(world, method_table, + return GPUInterpreter(meta, world, method_table, token, inf_cache, inf_params, opt_params) end function GPUInterpreter(interp::GPUInterpreter; + meta=interp.meta, world::UInt=interp.world, method_table::GPUMethodTableView=interp.method_table, token::Any=interp.token, inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, inf_params::CC.InferenceParams=interp.inf_params, opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table, + return GPUInterpreter(meta, world, method_table, token, inf_cache, inf_params, opt_params) end @@ -365,6 +368,7 @@ end else function GPUInterpreter(world::UInt=Base.get_world_counter(); + meta=nothing, method_table::MTType, code_cache::CodeCache, inf_params::CC.InferenceParams, @@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter(); method_table = get_method_table_view(world, method_table) inf_cache = Vector{CC.InferenceResult}() - return GPUInterpreter(world, method_table, + return GPUInterpreter(meta, world, method_table, code_cache, inf_cache, inf_params, opt_params) end function GPUInterpreter(interp::GPUInterpreter; + meta=interp.meta, world::UInt=interp.world, method_table::GPUMethodTableView=interp.method_table, code_cache::CodeCache=interp.code_cache, inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, inf_params::CC.InferenceParams=interp.inf_params, opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table, + return GPUInterpreter(meta, world, method_table, code_cache, inf_cache, inf_params, opt_params) end @@ -437,6 +442,8 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter, end +within_gpucompiler() = false + ## deferred compilation struct DeferredCallInfo <: CC.CallInfo @@ -445,9 +452,11 @@ struct DeferredCallInfo <: CC.CallInfo end # recognize calls to gpuc.deferred and save DeferredCallInfo metadata -function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f), - arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState, - max_methods::Int = CC.get_max_methods(interp, f, sv)) +# default implementation, extensible through meta argument. +# XXX: (or should we dispatch on `f`)? +function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f), + arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState, + max_methods::Int = CC.get_max_methods(interp, f, sv)) (; fargs, argtypes) = arginfo if f === var"gpuc.deferred" argvec = argtypes[2:end] @@ -458,7 +467,34 @@ function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f), else return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo) end + elseif f === within_gpucompiler + if length(argtypes) != 1 + @static if VERSION < v"1.11.0-" + return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo()) + else + return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo()) + end + end + @static if VERSION < v"1.11.0-" + return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure()) + else + return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, CC.MethodResultPure(),) + end end + return nothing +end + +function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f), + arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState, + max_methods::Int = CC.get_max_methods(interp, f, sv)) + candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods) + if candidate === nothing && interp.meta !== nothing + candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods) + end + if candidate !== nothing + return candidate + end + return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f, arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState, max_methods::Int) @@ -542,6 +578,101 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC end end +import .CC: CallInfo +struct NoInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt::Any # ::Type + kind::Symbol + NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) = + new(info, tt, kind) +end + +CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info) +CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx) +CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx) +struct AlwaysInlineCallInfo <: CallInfo + info::CallInfo # wrapped call + tt::Any # ::Type + AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt) +end + +CC.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info) +CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx) +CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx) + + +function inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo) + return nothing +end + +using Core.Compiler: ArgInfo, StmtInfo, AbsIntState +function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo, + si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) + ret = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any, arginfo::ArgInfo, + si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int) + + callinfo = nothing + if interp.meta !== nothing + callinfo = inlining_handler(interp.meta, interp, atype, ret.info) + end + if callinfo === nothing + callinfo = inlining_handler(nothing, interp, atype, ret.info) + end + if callinfo === nothing + callinfo = ret.info + end + + @static if VERSION ≥ v"1.11-" + return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + else + return CC.CallMeta(ret.rt, ret.effects, callinfo) + end +end + +let # overload `inlining_policy` + @static if VERSION ≥ v"1.11.0-DEV.879" + sigs_ex = :( + interp::GPUInterpreter, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), + stmt_flag::UInt32, + ) + args_ex = :( + interp::CC.AbstractInterpreter, + src::Any, + info::CC.CallInfo, + stmt_flag::UInt32, + ) + else + sigs_ex = :( + interp::GPUInterpreter, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) + args_ex = :( + interp::CC.AbstractInterpreter, + src::Any, + info::CC.CallInfo, + stmt_flag::UInt8, + mi::MethodInstance, + argtypes::Vector{Any}, + ) + end + @eval function Core.Compiler.inlining_policy($(sigs_ex.args...)) + if info isa NoInlineCallInfo + @safe_debug "Blocking inlining" info.tt info.kind + return nothing + elseif info isa AlwaysInlineCallInfo + @safe_debug "Forcing inlining for" info.tt + return src + end + return @invoke Core.Compiler.inlining_policy($(args_ex.args...)) + end +end + ## world view of the cache using Core.Compiler: WorldView @@ -704,7 +835,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) source = pop!(worklist) haskey(compiled, source) && continue # We have fulfilled the request already # Create a new compiler job for this edge, reusing the config settings from the inital one - job2 = CompilerJob(source, job.config) + job2 = CompilerJob(source, job.config) # TODO: GPUInterpreter.meta in config? llvm_mod2, outstanding = compile_method_instance(job2, compiled) append!(worklist, outstanding) # merge worklist with new outstanding edges @assert context(llvm_mod) == context(llvm_mod2) diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl index 40f9eea9..57843ab1 100644 --- a/test/ptx_tests.jl +++ b/test/ptx_tests.jl @@ -288,6 +288,14 @@ end @test !occursin("gpucompiler.mark", ir) end +@testset "within_gpucompiler" begin + function kernel(a) + unsafe_store!(a, GPUCompiler.within_gpucompiler()) + end + ir = sprint(io->code_llvm(io, kernel, Tuple{Int})) + @show ir +end + @testset "exception arguments" begin function kernel(a) unsafe_store!(a, trunc(Int, unsafe_load(a)))