Skip to content

Commit

Permalink
Fix private memory on the CPU
Browse files Browse the repository at this point in the history
Co-authored-by: Valentin Churavy <[email protected]>
  • Loading branch information
mwarusz and vchuravy committed Feb 27, 2020
1 parent c4d2487 commit 99dbaa1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 26 deletions.
25 changes: 5 additions & 20 deletions src/backends/cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,28 +132,13 @@ struct ScratchArray{N, D}
end

@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata))), T}(undef))
return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata)...)), T}(undef))
end

Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.getindex), A::ScratchArray{N}, I...) where N
nI = ntuple(Val(N+1)) do i
if i == N+1
__groupindex(ctx.metadata)
else
I[i]
end
end

return A.data[nI...]
Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.getindex), A::ScratchArray, I...)
return A.data[I...]
end

Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.setindex!), A::ScratchArray{N}, val, I...) where N
nI = ntuple(Val(N+1)) do i
if i == N+1
__groupindex(ctx.metadata)
else
I[i]
end
end
A.data[nI...] = val
Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.setindex!), A::ScratchArray, val, I...)
A.data[I...] = val
end
3 changes: 2 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ end

@inline __iterspace(cm::CompilerMetadata) = cm.iterspace
@inline __groupindex(cm::CompilerMetadata) = cm.groupindex
@inline __groupsize(cm::CompilerMetadata) = size(workitems(__iterspace(cm)))
@inline __dynamic_checkbounds(::CompilerMetadata{NDRange, CB}) where {NDRange, CB} = CB
@inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:StaticSize} = CartesianIndices(get(NDRange))
@inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:DynamicSize} = cm.ndrange
Expand All @@ -31,7 +32,7 @@ include("compiler/pass.jl")

function generate_overdubs(Ctx)
@eval begin
@inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = size(workitems(__iterspace(ctx.metadata)))
@inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = __groupsize(ctx.metadata)
@inline Cassette.overdub(ctx::$Ctx, ::typeof(__workitems_iterspace)) = workitems(__iterspace(ctx.metadata))

###
Expand Down
24 changes: 19 additions & 5 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import MacroTools: splitdef, combinedef, isexpr
import MacroTools: splitdef, combinedef, isexpr, postwalk

# XXX: Proper errors
function __kernel(expr)
Expand Down Expand Up @@ -104,6 +104,7 @@ struct WorkgroupLoop
indicies :: Vector{Any}
stmts :: Vector{Any}
allocations :: Vector{Any}
private :: Vector{Any}
end


Expand All @@ -116,12 +117,13 @@ function split(stmts)
current = Any[]
indicies = Any[]
allocations = Any[]
private = Any[]

loops = WorkgroupLoop[]
for stmt in stmts.args
if isexpr(stmt, :macrocall)
if stmt.args[1] === Symbol("@synchronize")
loop = WorkgroupLoop(deepcopy(indicies), current, allocations)
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
push!(loops, loop)
allocations = Any[]
current = Any[]
Expand All @@ -137,10 +139,13 @@ function split(stmts)
push!(indicies, stmt)
continue
elseif callee === Symbol("@localmem") ||
callee === Symbol("@private") ||
callee === Symbol("@uniform")
push!(allocations, stmt)
continue
elseif callee === Symbol("@private")
push!(allocations, stmt)
push!(private, stmt.args[1])
continue
end
end
end
Expand All @@ -150,7 +155,7 @@ function split(stmts)

# everything since the last `@synchronize`
if !isempty(current)
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations))
push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private)))
end
return loops
end
Expand All @@ -163,12 +168,21 @@ function emit(loop)
rhs = stmt.args[2]
push!(rhs.args, idx)
end
body = Expr(:block, loop.stmts...)
body = postwalk(body) do expr
if @capture(expr, A_[i__])
if A in loop.private
return :($A[$(i...), $(idx).I...])
end
end
return expr
end
quote
$(loop.allocations...)
for $idx in $__workitems_iterspace()
$__validindex($idx) || continue
$(loop.indicies...)
$(loop.stmts...)
$(body)
end
end
end
33 changes: 33 additions & 0 deletions test/private.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using KernelAbstractions
using Test
using CUDAapi
if has_cuda_gpu()
using CuArrays
CuArrays.allowscalar(false)
end

@kernel function private(A)
N = prod(groupsize())
I = @index(Global, Linear)
i = @index(Local, Linear)
priv = @private Int (1,)
priv[1] = N - i + 1
@synchronize
A[I] = priv[1]
end

function harness(backend, ArrayT)
A = ArrayT{Int}(undef, 64)
wait(private(backend, 16)(A, ndrange=size(A)))
@test all(A[1:16] .== 16:-1:1)
@test all(A[17:32] .== 16:-1:1)
@test all(A[33:48] .== 16:-1:1)
@test all(A[49:64] .== 16:-1:1)
end

@testset "kernels" begin
harness(CPU(), Array)
if has_cuda_gpu()
harness(CUDA(), CuArray)
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ end
include("localmem.jl")
end

@testset "Private" begin
include("private.jl")
end

@testset "Unroll" begin
include("unroll.jl")
end
Expand Down

0 comments on commit 99dbaa1

Please sign in to comment.