Skip to content

Commit

Permalink
Rework deferred compilation mechanism.
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 26, 2024
1 parent 316668b commit a018c17
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 129 deletions.
55 changes: 28 additions & 27 deletions examples/jit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,31 @@ function get_trampoline(job)
return addr
end

import GPUCompiler: deferred_codegen_jobs
"""
    deferred_codegen(f, Val(tt), Val(world))::Ptr{Cvoid}

Build a `CompilerJob` for `f` applied to argument types `tt` in world age
`world`, register it under a JIT trampoline address, and return (from the
generated code) a function pointer through which the lazily-compiled code
can be invoked.

NOTE(review): the compilation setup runs at generation time, which the XXX
comments below flag as unsupported — kept as-is from the original example.
"""
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
    # manual version of native_job because we have a function type
    source = methodinstance(F, Base.to_tuple_type(tt), world)
    target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
    # XXX: do we actually require the Julia runtime?
    #      with jlruntime=false, we reach an unreachable.
    params = TestCompilerParams()
    config = CompilerConfig(target, params; kernel=false)
    job = CompilerJob(source, config, world)
    # XXX: invoking GPUCompiler from a generated function is not allowed!
    #      for things to work, we need to forward the correct world, at least.

    # Obtain a lazy-compilation trampoline for the job; its address doubles as
    # a unique id into `deferred_codegen_jobs`.
    addr = get_trampoline(job)
    trampoline = pointer(addr)
    id = Base.reinterpret(Int, trampoline)

    # Record the job so the JIT can find it when the trampoline fires.
    deferred_codegen_jobs[id] = job

    # The generated code only resolves the trampoline into a callable pointer.
    quote
        ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
        assume(ptr != C_NULL)
        return ptr
    end
end
# import GPUCompiler: deferred_codegen_jobs
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
# # manual version of native_job because we have a function type
# source = methodinstance(F, Base.to_tuple_type(tt), world)
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# # XXX: do we actually require the Julia runtime?
# # with jlruntime=false, we reach an unreachable.
# params = TestCompilerParams()
# config = CompilerConfig(target, params; kernel=false)
# job = CompilerJob(source, config, world)
# # XXX: invoking GPUCompiler from a generated function is not allowed!
# # for things to work, we need to forward the correct world, at least.

# addr = get_trampoline(job)
# trampoline = pointer(addr)
# id = Base.reinterpret(Int, trampoline)

# deferred_codegen_jobs[id] = job

# quote
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
# assume(ptr != C_NULL)
# return ptr
# end
# end

@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
argtt = tt.parameters[1]
Expand Down Expand Up @@ -224,8 +224,9 @@ end
@inline function call_delayed(f::F, args...) where F
tt = Tuple{map(Core.Typeof, args)...}
rt = Core.Compiler.return_type(f, tt)
world = GPUCompiler.tls_world_age()
ptr = deferred_codegen(f, Val(tt), Val(world))
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
# But that will only be needed here, and in Enzyme...
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
abi_call(ptr, rt, tt, f, args...)
end

Expand Down
175 changes: 78 additions & 97 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ function JuliaContext(f; kwargs...)
end


## deferred compilation

"""
    var"gpuc.deferred"(f, args...)::Ptr{Cvoid}

Marker intrinsic for deferred compilation: written as if it called
`f(args...)`, but instead of performing the call it puts down a marker and
returns a pointer to the compiled function, to be invoked later by the
caller.
"""
function var"gpuc.deferred" end

## compiler entrypoint

export compile
Expand Down Expand Up @@ -127,33 +138,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
error("Unknown compilation output $output")
end

# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
#
# maps a trampoline/marker id (an `Int`) to what should be compiled for it:
# either a full `CompilerJob`, or a `(; ft, tt)` tuple resolved during codegen.
const deferred_codegen_jobs = Dict{Int, Any}()

# Identity fallback exported as a C symbol: gives OrcJIT a concrete
# `deferred_codegen` entry point to drive lazy compilation from, while also
# enabling recursive compilation.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
    return ptr
end

# Generated so that each (ft, tt) specialization registers itself exactly once,
# at generation time, with the id baked into the returned code.
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
    # allocate a fresh id and remember which function/type-tuple it stands for;
    # codegen later replaces the marker call with the compiled function pointer.
    id = length(deferred_codegen_jobs) + 1
    deferred_codegen_jobs[id] = (; ft, tt)
    # don't bother looking up the method instance, as we'll do so again during codegen
    # using the world age of the parent.
    #
    # this also works around an issue on <1.10, where we don't know the world age of
    # generated functions so use the current world counter, which may be too new
    # for the world we're compiling for.

    quote
        # TODO: add an edge to this method instance to support method redefinitions
        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
    end
end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
Expand Down Expand Up @@ -183,79 +167,76 @@ const __llvm_initialized = Ref(false)
entry = finish_module!(job, ir, entry)

# deferred code generation
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
jobs = Dict{CompilerJob, String}(job => entry_fn)
if has_deferred_jobs
dyn_marker = functions(ir)["deferred_codegen"]

# iterative compilation (non-recursive)
changed = true
while changed
changed = false

# find deferred compiler
# TODO: recover this information earlier, from the Julia IR
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
id = convert(Int, first(operands(call)))

global deferred_codegen_jobs
dyn_val = deferred_codegen_jobs[id]

# get a job in the appopriate world
dyn_job = if dyn_val isa CompilerJob
# trust that the user knows what they're doing
dyn_val
run_optimization_for_deferred = false
if haskey(functions(ir), "gpuc.lookup")
run_optimization_for_deferred = true
dyn_marker = functions(ir)["gpuc.lookup"]

# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
# target method instance from the LLVM IR
function find_base_object(val)
while true
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
opcode(val) == LLVM.API.LLVMBitCast ||
opcode(val) == LLVM.API.LLVMAddrSpaceCast)
val = first(operands(val))
elseif val isa LLVM.IntToPtrInst ||
val isa LLVM.BitCastInst ||
val isa LLVM.AddrSpaceCastInst
val = first(operands(val))
elseif val isa LLVM.LoadInst
# In 1.11+ we no longer embed integer constants directly.
gv = first(operands(val))
if gv isa LLVM.GlobalValue
val = LLVM.initializer(gv)
continue
end
break
else
ft, tt = dyn_val
dyn_src = methodinstance(ft, tt, tls_world_age())
CompilerJob(dyn_src, job.config)
break
end

push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
end
return val
end

# compile and link
for dyn_job in keys(worklist)
# cached compilation
dyn_entry_fn = get!(jobs, dyn_job) do
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
parent_job=job)
dyn_entry_fn = LLVM.name(dyn_meta.entry)
merge!(compiled, dyn_meta.compiled)
@assert context(dyn_ir) == context(ir)
link!(ir, dyn_ir)
changed = true
dyn_entry_fn
end
dyn_entry = functions(ir)[dyn_entry_fn]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_job]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_entry, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_entry, T_ptr)
else
ptrtoint!(builder, dyn_entry, T_ptr)
end
replace_uses!(call, fptr)
worklist = Dict{Any, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
dyn_mi_inst = find_base_object(operands(call)[1])
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
dyn_mi = Base.unsafe_pointer_to_objref(
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
end

for dyn_mi in keys(worklist)
dyn_fn_name = compiled[dyn_mi].specfunc
dyn_fn = functions(ir)[dyn_fn_name]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_mi]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_fn, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_fn, T_ptr)
else
ptrtoint!(builder, dyn_fn, T_ptr)
end
erase!(call)
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end

# all deferred compilations should have been resolved
@compiler_assert isempty(uses(dyn_marker)) job
erase!(dyn_marker)
unsafe_delete!(ir, dyn_marker)
end

if libraries
Expand Down Expand Up @@ -285,7 +266,7 @@ const __llvm_initialized = Ref(false)
# global variables. this makes sure that the optimizer can, e.g.,
# rewrite function signatures.
if toplevel
preserved_gvs = collect(values(jobs))
preserved_gvs = [entry_fn]
for gvar in globals(ir)
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
push!(preserved_gvs, LLVM.name(gvar))
Expand Down Expand Up @@ -317,7 +298,7 @@ const __llvm_initialized = Ref(false)
# deferred codegen has some special optimization requirements,
# which also need to happen _after_ regular optimization.
# XXX: make these part of the optimizer pipeline?
if has_deferred_jobs
if run_optimization_for_deferred
@dispose pb=NewPMPassBuilder() begin
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
Expand Down Expand Up @@ -353,15 +334,15 @@ const __llvm_initialized = Ref(false)
# finish the module
#
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
# during deferred code generation. Instead, process the merged module
# from all the jobs here.
if toplevel
entry = finish_ir!(job, ir, entry)

for (job′, fn′) in jobs
job′ == job && continue
finish_ir!(job′, ir, functions(ir)[fn′])
end
# for (job′, fn′) in jobs
# job′ == job && continue
# finish_ir!(job′, ir, functions(ir)[fn′])
# end
end

# replace non-entry function definitions with a declaration
Expand Down
11 changes: 11 additions & 0 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ function irgen(@nospecialize(job::CompilerJob))
compiled[job.source] =
(; compiled[job.source].ci, func, specfunc)

# Earlier we sanitized global names; this invalidates the
# func/specfunc names saved in `compiled`. Update the names now,
# so that when we use the compiled mappings to look up the
# LLVM function for a method instance (deferred codegen) we have
# valid targets.
for mi in keys(compiled)
mi == job.source && continue
ci, func, specfunc = compiled[mi]
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
end

# minimal required optimization
@timeit_debug to "rewrite" begin
if job.config.kernel && needs_byval(job)
Expand Down
Loading

0 comments on commit a018c17

Please sign in to comment.