# Patch: index `@private` scratch memory per workitem on the CPU backend.
#
# - src/backends/cpu.jl: the Scratchpad allocation splats the group size into the
#   MArray dimensions; the getindex/setindex! overdubs forward the indices they are
#   given unchanged instead of appending the group index themselves.
# - src/compiler.jl: adds `__groupsize(::CompilerMetadata)`; the `groupsize` overdub
#   now calls it.
# - src/macros.jl: `split` records the names bound by `@private`; `emit` rewrites
#   `A[i...]` on those names to `A[i..., idx.I...]` inside the generated workitem
#   loop. `@capture` is imported explicitly next to `postwalk` because `emit` uses it,
#   so the post-image hash on the macros.jl `index` line no longer matches.
# - test/private.jl, test/runtests.jl: new test that writes private memory, crosses a
#   `@synchronize`, and reads it back.
#
# NOTE(review): line breaks, blank context lines and indentation were reconstructed
# from the hunk line counts; verify them against the tree before applying.
diff --git a/src/backends/cpu.jl b/src/backends/cpu.jl
index cf449cfd..0977eecb 100644
--- a/src/backends/cpu.jl
+++ b/src/backends/cpu.jl
@@ -132,28 +132,13 @@ struct ScratchArray{N, D}
 end
 
 @inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
-    return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata))), T}(undef))
+    return ScratchArray{length(Dims)}(MArray{__size((Dims..., __groupsize(ctx.metadata)...)), T}(undef))
 end
 
-Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.getindex), A::ScratchArray{N}, I...) where N
-    nI = ntuple(Val(N+1)) do i
-        if i == N+1
-            __groupindex(ctx.metadata)
-        else
-            I[i]
-        end
-    end
-
-    return A.data[nI...]
+Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.getindex), A::ScratchArray, I...)
+    return A.data[I...]
 end
 
-Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.setindex!), A::ScratchArray{N}, val, I...) where N
-    nI = ntuple(Val(N+1)) do i
-        if i == N+1
-            __groupindex(ctx.metadata)
-        else
-            I[i]
-        end
-    end
-    A.data[nI...] = val
+Base.@propagate_inbounds function Cassette.overdub(ctx::CPUCtx, ::typeof(Base.setindex!), A::ScratchArray, val, I...)
+    A.data[I...] = val
 end
diff --git a/src/compiler.jl b/src/compiler.jl
index 5852892b..89120857 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -22,6 +22,7 @@ end
 
 @inline __iterspace(cm::CompilerMetadata) = cm.iterspace
 @inline __groupindex(cm::CompilerMetadata) = cm.groupindex
+@inline __groupsize(cm::CompilerMetadata) = size(workitems(__iterspace(cm)))
 @inline __dynamic_checkbounds(::CompilerMetadata{NDRange, CB}) where {NDRange, CB} = CB
 @inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:StaticSize} = CartesianIndices(get(NDRange))
 @inline __ndrange(cm::CompilerMetadata{NDRange}) where {NDRange<:DynamicSize} = cm.ndrange
@@ -31,7 +32,7 @@ include("compiler/pass.jl")
 
 function generate_overdubs(Ctx)
     @eval begin
-        @inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = size(workitems(__iterspace(ctx.metadata)))
+        @inline Cassette.overdub(ctx::$Ctx, ::typeof(groupsize)) = __groupsize(ctx.metadata)
         @inline Cassette.overdub(ctx::$Ctx, ::typeof(__workitems_iterspace)) = workitems(__iterspace(ctx.metadata))
 
         ###
diff --git a/src/macros.jl b/src/macros.jl
index 604a4333..7ecf05a8 100644
--- a/src/macros.jl
+++ b/src/macros.jl
@@ -1,4 +1,4 @@
-import MacroTools: splitdef, combinedef, isexpr
+import MacroTools: splitdef, combinedef, isexpr, postwalk, @capture
 
 # XXX: Proper errors
 function __kernel(expr)
@@ -104,6 +104,7 @@ struct WorkgroupLoop
     indicies :: Vector{Any}
     stmts :: Vector{Any}
     allocations :: Vector{Any}
+    private :: Vector{Any}
 end
 
 
@@ -116,12 +117,13 @@ function split(stmts)
     current = Any[]
     indicies = Any[]
     allocations = Any[]
+    private = Any[]
     loops = WorkgroupLoop[]
 
     for stmt in stmts.args
         if isexpr(stmt, :macrocall)
             if stmt.args[1] === Symbol("@synchronize")
-                loop = WorkgroupLoop(deepcopy(indicies), current, allocations)
+                loop = WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private))
                 push!(loops, loop)
                 allocations = Any[]
                 current = Any[]
@@ -137,10 +139,13 @@ function split(stmts)
                     push!(indicies, stmt)
                     continue
                 elseif callee === Symbol("@localmem") ||
-                       callee === Symbol("@private") ||
                        callee === Symbol("@uniform")
                     push!(allocations, stmt)
                     continue
+                elseif callee === Symbol("@private")
+                    push!(allocations, stmt)
+                    push!(private, stmt.args[1])
+                    continue
                 end
             end
         end
@@ -150,7 +155,7 @@ function split(stmts)
 
     # everything since the last `@synchronize`
     if !isempty(current)
-        push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations))
+        push!(loops, WorkgroupLoop(deepcopy(indicies), current, allocations, deepcopy(private)))
     end
     return loops
 end
@@ -163,12 +168,21 @@ function emit(loop)
         rhs = stmt.args[2]
         push!(rhs.args, idx)
     end
+    body = Expr(:block, loop.stmts...)
+    body = postwalk(body) do expr
+        if @capture(expr, A_[i__])
+            if A in loop.private
+                return :($A[$(i...), $(idx).I...])
+            end
+        end
+        return expr
+    end
     quote
         $(loop.allocations...)
         for $idx in $__workitems_iterspace()
             $__validindex($idx) || continue
             $(loop.indicies...)
-            $(loop.stmts...)
+            $(body)
         end
     end
 end
diff --git a/test/private.jl b/test/private.jl
new file mode 100644
index 00000000..71056879
--- /dev/null
+++ b/test/private.jl
@@ -0,0 +1,33 @@
+using KernelAbstractions
+using Test
+using CUDAapi
+if has_cuda_gpu()
+    using CuArrays
+    CuArrays.allowscalar(false)
+end
+
+@kernel function private(A)
+    N = prod(groupsize())
+    I = @index(Global, Linear)
+    i = @index(Local, Linear)
+    priv = @private Int (1,)
+    priv[1] = N - i + 1
+    @synchronize
+    A[I] = priv[1]
+end
+
+function harness(backend, ArrayT)
+    A = ArrayT{Int}(undef, 64)
+    wait(private(backend, 16)(A, ndrange=size(A)))
+    @test all(A[1:16] .== 16:-1:1)
+    @test all(A[17:32] .== 16:-1:1)
+    @test all(A[33:48] .== 16:-1:1)
+    @test all(A[49:64] .== 16:-1:1)
+end
+
+@testset "kernels" begin
+    harness(CPU(), Array)
+    if has_cuda_gpu()
+        harness(CUDA(), CuArray)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index c7cadcfd..fac14c87 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,6 +9,10 @@ end
     include("localmem.jl")
 end
 
+@testset "Private" begin
+    include("private.jl")
+end
+
 @testset "Unroll" begin
     include("unroll.jl")
 end