Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experiment] Add deferred_with #599

Closed
wants to merge 15 commits into the base branch from the feature branch
78 changes: 75 additions & 3 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
error("Unknown compilation output $output")
end

# GPUCompiler intrinsic that marks a call for deferred compilation; it has no
# methods — calls are intercepted during inference (see `CC.abstract_call_known`
# in jlgen.jl) and eventually lowered to a function-pointer lookup.
function var"gpuc.deferred" end

# Same as `gpuc.deferred`, but takes an extra leading argument selecting the
# target backend/config — NOTE(review): semantics are experimental in this PR.
function var"gpuc.deferred.with" end

# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
Expand Down Expand Up @@ -157,6 +163,29 @@ end
end
end

# Peel pointer-conversion wrappers (constant expressions and cast instructions)
# off `val` until the underlying base object is reached. On Julia 1.11+, where
# integer constants are no longer embedded directly, also look through loads
# from global variables by following the global's initializer.
function find_base_object(val)
    while true
        if val isa ConstantExpr
            op = opcode(val)
            op in (LLVM.API.LLVMIntToPtr,
                   LLVM.API.LLVMBitCast,
                   LLVM.API.LLVMAddrSpaceCast) || return val
            val = first(operands(val))
        elseif val isa Union{LLVM.IntToPtrInst,
                             LLVM.BitCastInst,
                             LLVM.AddrSpaceCastInst}
            val = first(operands(val))
        elseif val isa LLVM.LoadInst
            # In 1.11+ we no longer embed integer constants directly.
            gv = first(operands(val))
            gv isa LLVM.GlobalValue || return val
            val = LLVM.initializer(gv)
        else
            return val
        end
    end
end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob);
Expand Down Expand Up @@ -186,8 +215,8 @@ const __llvm_initialized = Ref(false)
entry = finish_module!(job, ir, entry)

# deferred code generation
has_deferred_jobs = !only_entry && toplevel &&
haskey(functions(ir), "deferred_codegen")
has_deferred_jobs = !only_entry && toplevel && haskey(functions(ir), "deferred_codegen")

jobs = Dict{CompilerJob, String}(job => entry_fn)
if has_deferred_jobs
dyn_marker = functions(ir)["deferred_codegen"]
Expand All @@ -198,7 +227,6 @@ const __llvm_initialized = Ref(false)
changed = false

# find deferred compiler
# TODO: recover this information earlier, from the Julia IR
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
Expand Down Expand Up @@ -260,6 +288,50 @@ const __llvm_initialized = Ref(false)
end

# all deferred compilations should have been resolved
if dyn_marker !== nothing
@compiler_assert isempty(uses(dyn_marker)) job
unsafe_delete!(ir, dyn_marker)
end
end

if haskey(functions(ir), "gpuc.lookup")
dyn_marker = functions(ir)["gpuc.lookup"]

worklist = Dict{Any, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
dyn_mi_inst = find_base_object(operands(call)[1])
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
dyn_mi = Base.unsafe_pointer_to_objref(
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
end

for dyn_mi in keys(worklist)
dyn_fn_name = compiled[dyn_mi].specfunc
dyn_fn = functions(ir)[dyn_fn_name]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_mi]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_fn, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_fn, T_ptr)
else
ptrtoint!(builder, dyn_fn, T_ptr)
end
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end

@compiler_assert isempty(uses(dyn_marker)) job
unsafe_delete!(ir, dyn_marker)
end
Expand Down
157 changes: 153 additions & 4 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ else
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
end

struct GPUInterpreter <: CC.AbstractInterpreter
abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
struct GPUInterpreter <: AbstractGPUInterpreter
world::UInt
method_table::GPUMethodTableView

Expand Down Expand Up @@ -435,6 +436,113 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
return ret
end

# Inference metadata recorded for a `gpuc.deferred`(`.with`) call: preserves the
# deferred call's inferred return type and its original `CallInfo` so that
# `CC.handle_call!` can later rewrite the call during inlining.
struct DeferredCallInfo <: CC.CallInfo
rt::DataType        # inferred return type of the deferred call
info::CC.CallInfo   # inference info for the underlying call
end

# Intercept abstract interpretation of calls to the `gpuc.deferred` /
# `gpuc.deferred.with` intrinsics: infer the wrapped call normally, stash the
# result in a `DeferredCallInfo`, and report `Ptr{Cvoid}` (the eventual
# function pointer) as this call's return type. Other calls fall through to
# the default implementation via `@invoke`.
function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int = CC.get_max_methods(interp, f, sv))
(; fargs, argtypes) = arginfo
if f === var"gpuc.deferred" || f === var"gpuc.deferred.with"
# `gpuc.deferred(f, args...)`: callee begins at argument 2;
# `gpuc.deferred.with(cfg, f, args...)`: callee begins at argument 3.
first_arg = f === var"gpuc.deferred" ? 2 : 3
argvec = argtypes[first_arg:end]
# Infer the deferred call itself so its return type and info are available.
call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
callinfo = DeferredCallInfo(call.rt, call.info)
# `CallMeta` gained an extra exception-type field in 1.11.
@static if VERSION < v"1.11.0-"
return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
else
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
end
end
return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
end

# Use the Inlining infrastructure to perform our refinement: rewrite calls
# carrying a `DeferredCallInfo` into a `:foreigncall` of the "gpuc.lookup"
# intrinsic, which the LLVM driver later resolves to a function pointer of the
# compiled target.
# NOTE(review): the inlining statement-flag type appears to have widened from
# UInt8 to UInt32 in 1.11 — confirm against the targeted Julia versions.
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
function CC.handle_call!(todo::Vector{Pair{Int,Any}},
ir::CC.IRCode, idx::CC.Int, stmt::Expr, info::DeferredCallInfo, flag::FlagType, sig::CC.Signature,
state::CC.InliningState)

minfo = info.info
results = minfo.results
# Only rewrite when dispatch is unambiguous; otherwise leave the call alone.
if length(results.matches) != 1
return nothing
end
match = only(results.matches)

# lookup the target mi with correct edge tracking
case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state), info)
@assert case isa CC.InvokeCase
@assert stmt.head === :call

f = stmt.args[1]
# The `.with` variant lowers to a distinct intrinsic and carries one extra
# (backend-selector) argument, reflected in the signature tuple below.
name = f === var"gpuc.deferred" ? "extern gpuc.lookup" : "extern gpuc.lookup.with"
with_arg_T = f === var"gpuc.deferred" ? () : (Any,)

# :foreigncall layout: name, return type, argument types, nreq, calling
# convention, then the arguments themselves (args[6] = target, via case.invoke).
args = Any[
name,
Ptr{Cvoid},
Core.svec(Any, Any, with_arg_T..., match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
0,
QuoteNode(:llvmcall),
case.invoke,
stmt.args[2:end]...
]
# Mutate the statement in place into the foreigncall.
stmt.head = :foreigncall
stmt.args = args
return nothing
end

# Analysis result carrying the method instances targeted by deferred
# compilation within a piece of IR; attached to the inference result and
# consumed later when collecting outstanding compilation jobs.
struct DeferredEdges
edges::Vector{MethodInstance}
end

# Scan optimized IR for "gpuc.lookup" foreigncalls (emitted by `CC.handle_call!`)
# and collect the method instances they target, deduplicated, so the driver can
# schedule them for compilation. Per the foreigncall layout built in
# `handle_call!`, args[6] holds the target; the `.with` variant additionally
# carries a backend selector in args[7].
function find_deferred_edges(ir::CC.IRCode)
edges = MethodInstance[]
# @aviatesk: Can we add this instead in handle_call
for stmt in ir.stmts
inst = stmt[:inst]
inst isa Expr || continue
expr = inst::Expr
if expr.head === :foreigncall &&
expr.args[1] == "extern gpuc.lookup"
deferred_mi = expr.args[6]
push!(edges, deferred_mi)
elseif expr.head === :foreigncall &&
expr.args[1] == "extern gpuc.lookup.with"
# Track the edge for the `.with` variant too; previously this branch
# only `@show`ed the target (debug output during inference) and never
# recorded it, so deferred `.with` targets were dropped.
# TODO(review): the backend selector (expr.args[7]) is not propagated
# yet — the target is currently compiled under the caller's config.
deferred_mi = expr.args[6]
push!(edges, deferred_mi)
end
end
unique!(edges)
return edges
end

# Persist the deferred edges discovered in optimized IR onto the inference
# result, so they can be recovered from the code cache after compilation.
if VERSION >= v"1.11.0-"
# stack_analysis_result and ipo_dataflow_analysis is 1.11 only
function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode, caller::CC.InferenceResult)
edges = find_deferred_edges(ir)
if !isempty(edges)
# Attached as a stacked analysis result; read back elsewhere via
# `CC.traverse_analysis_results`.
CC.stack_analysis_result!(caller, DeferredEdges(edges))
end
@invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode, caller::CC.InferenceResult)
end
else
# v1.10.0
function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode, caller::CC.InferenceResult)
edges = find_deferred_edges(ir)
if !isempty(edges)
# This is a tad bit risky, but nobody should be running EA on our results.
# (The `argescapes` slot is repurposed to smuggle the edges through
# the cache on 1.10, which lacks the analysis-results mechanism.)
caller.argescapes = DeferredEdges(edges)
end
@invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState, ir::CC.IRCode, caller::CC.InferenceResult)
end
end

## world view of the cache
using Core.Compiler: WorldView
Expand Down Expand Up @@ -584,6 +692,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
end

compiled = IdDict()
llvm_mod, outstanding = compile_method_instance(job, compiled)
worklist = outstanding
while !isempty(worklist)
source = pop!(worklist)
haskey(compiled, source) && continue
job2 = CompilerJob(source, job.config)
@debug "Processing..." job2
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
append!(worklist, outstanding)
@assert context(llvm_mod) == context(llvm_mod2)
link!(llvm_mod, llvm_mod2)
end

return llvm_mod, compiled
end

function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
# populate the cache
interp = get_interpreter(job)
cache = CC.code_cache(interp)
Expand All @@ -594,7 +720,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))

# create a callback to look-up function in our cache,
# and keep track of the method instances we needed.
method_instances = []
method_instances = Any[]
if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
function lookup_fun(mi, min_world, max_world)
push!(method_instances, mi)
Expand Down Expand Up @@ -659,7 +785,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
end

# process all compiled method instances
compiled = Dict()
for mi in method_instances
ci = ci_cache_lookup(cache, mi, job.world, job.world)
ci === nothing && continue
Expand Down Expand Up @@ -696,10 +821,34 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
end

# Collect the deferred edges
outstanding = Any[]
for mi in method_instances
!haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
ci = compiled[mi].ci
@static if VERSION >= v"1.11.0-"
edges = CC.traverse_analysis_results(ci) do @nospecialize result
return result isa DeferredEdges ? result : return
end
else
edges = ci.argescapes
if !(edges isa Union{Nothing, DeferredEdges})
edges = nothing
end
end
if edges !== nothing
for deferred_mi in (edges::DeferredEdges).edges
if !haskey(compiled, deferred_mi)
push!(outstanding, deferred_mi)
end
end
end
end

# ensure that the requested method instance was compiled
@assert haskey(compiled, job.source)

return llvm_mod, compiled
return llvm_mod, outstanding
end

# partially revert JuliaLangjulia#49391
Expand Down
14 changes: 14 additions & 0 deletions test/native_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ end
ir = fetch(t)
@test contains(ir, r"add i64 %\d+, 3")
end

@testset "deferred" begin
    # Use gensym'd names so repeated (re-)runs don't clash with methods
    # previously defined in this module.
    @gensym child kernel unrelated
    @eval @noinline $child(i) = i
    @eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)

    # smoke test
    job, _ = Native.create_job(eval(kernel), (Int64,))

    # the deferred intrinsic must infer as a function pointer
    ci, rt = only(GPUCompiler.code_typed(job))
    @test rt === Ptr{Cvoid}

    # codegen should succeed and produce non-empty IR
    # (previously the sprint result was computed but never checked)
    ir = sprint(io->GPUCompiler.code_llvm(io, job))
    @test !isempty(ir)
end
end

############################################################################################
Expand Down
Loading