Add caching allocator interface #576

Merged · 28 commits · Jan 9, 2025
Changes from 23 commits
2 changes: 2 additions & 0 deletions Project.toml
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

@@ -23,6 +24,7 @@ LinearAlgebra = "1"
Printf = "1"
Random = "1"
Reexport = "1"
ScopedValues = "1"
Serialization = "1"
Statistics = "1"
julia = "1.10"
1 change: 1 addition & 0 deletions docs/.gitignore
@@ -1,3 +1,4 @@
Manifest.toml
build
site
Manifest.toml
1 change: 1 addition & 0 deletions docs/make.jl
@@ -20,6 +20,7 @@ function main()
"Test suite" => "testsuite.md",
],
doctest = true,
warnonly = [:missing_docs],
)

deploydocs(
11 changes: 9 additions & 2 deletions docs/src/interface.md
@@ -10,7 +10,7 @@ Device functionality is then handled by [KernelAbstractions.jl](https://github.c

You should provide an array type that builds on the `AbstractGPUArray` supertype, such as:

```
```julia
mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N}
data::DataRef{Vector{UInt8}}
offset::Int
@@ -23,10 +23,17 @@ end
This will allow your defined type (in this case `CustomArray`) to use the GPUArrays interface where available.
To actually use the functionality defined for `AbstractGPUArray`s, you also need to define the backend, like so:

```
```julia
import KernelAbstractions: Backend
struct CustomBackend <: KernelAbstractions.GPU end
KernelAbstractions.get_backend(a::CA) where CA <: CustomArray = CustomBackend()
```

There are numerous examples of potential interfaces for GPUArrays, such as with [JLArrays](https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/JLArrays/src/JLArrays.jl), [CuArrays](https://github.com/JuliaGPU/CUDA.jl/blob/master/src/gpuarrays.jl), and [ROCArrays](https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/gpuarrays.jl).

## Caching Allocator

```@docs
GPUArrays.@cached
GPUArrays.@uncached
```
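
Backends can opt their allocations into this cache by routing array construction through `GPUArrays.cached_alloc`, as the `JLArrays` change in this PR does.
The sketch below is illustrative only: `CustomArray` and `raw_allocate` are placeholder names from this guide, not part of the GPUArrays API.

```julia
# Hypothetical constructor for the CustomArray type defined above.
function CustomArray{T, N}(::UndefInitializer, dims::Dims{N}) where {T, N}
    # Key the cache on whatever uniquely identifies a reusable buffer:
    # here the array type, element type, and shape.
    GPUArrays.cached_alloc((CustomArray, T, dims)) do
        # Only reached when no cache is active in the current scope,
        # or when no free buffer matches the key.
        data = raw_allocate(prod(dims) * sizeof(T))  # placeholder backend allocation
        return CustomArray{T, N}(data, dims)
    end::CustomArray{T, N}
end
```

Arrays allocated this way are returned to the cache's "free" pool when the surrounding [`GPUArrays.@cached`](@ref) block finishes, so later allocations with the same key can reuse them.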
17 changes: 11 additions & 6 deletions lib/JLArrays/src/JLArrays.jl
@@ -88,12 +88,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
check_eltype(T)
maxsize = prod(dims) * sizeof(T)
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data) do data
resize!(data, 0)
end
obj = new{T,N}(ref, 0, dims)
finalizer(unsafe_free!, obj)

# Reuse a previously freed buffer for this (array type, eltype, dims) key when
# an AllocCache is active in the current scope; otherwise allocate as usual.
GPUArrays.cached_alloc((JLArray, T, dims)) do
data = Vector{UInt8}(undef, maxsize)
ref = DataRef(data) do data
resize!(data, 0)
end
obj = new{T, N}(ref, 0, dims)
finalizer(unsafe_free!, obj)
return obj
end::JLArray{T,N}
end

# low-level constructor for wrapping existing data
@@ -102,6 +106,7 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
check_eltype(T)
obj = new{T,N}(ref, offset, dims)
finalizer(unsafe_free!, obj)
return obj
end
end

1 change: 1 addition & 0 deletions src/GPUArrays.jl
@@ -34,6 +34,7 @@ include("host/random.jl")
include("host/quirks.jl")
include("host/uniformscaling.jl")
include("host/statistics.jl")
include("host/alloc_cache.jl")


end # module
157 changes: 157 additions & 0 deletions src/host/alloc_cache.jl
@@ -0,0 +1,157 @@
using ..GPUArrays

@static if VERSION < v"1.11"
using ScopedValues
else
using Base.ScopedValues
end

mutable struct AllocCache{T <: AbstractGPUArray}
lock::ReentrantLock
busy::Dict{UInt64, Vector{T}} # hash(key) => GPUArray[]
free::Dict{UInt64, Vector{T}}

function AllocCache(::Type{T}) where {T <: AbstractGPUArray}
cache = new{T}(
ReentrantLock(),
Dict{UInt64, Vector{T}}(),
Dict{UInt64, Vector{T}}()
)
return finalizer(unsafe_free!, cache)
end
end

function get_pool!(cache::AllocCache{T}, pool::Symbol, uid::UInt64) where {T <: AbstractGPUArray}
pool = getproperty(cache, pool)
uid_pool = get(pool, uid, nothing)
if uid_pool ≡ nothing
uid_pool = Base.@lock cache.lock pool[uid] = T[]
end
return uid_pool
end

function cached_alloc(f, key)
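# When no cache is active in the current scope, fall back to a regular allocation.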
cache = ALLOC_CACHE[]
if cache === nothing
return f()
end

x = nothing
uid = hash(key)

busy_pool = get_pool!(cache, :busy, uid)
free_pool = get_pool!(cache, :free, uid)
isempty(free_pool) && (x = f())

while !isempty(free_pool) && x ≡ nothing
tmp = Base.@lock cache.lock pop!(free_pool)
# Array was manually freed via `unsafe_free!`.
GPUArrays.storage(tmp).freed && continue
x = tmp
end

# If no reusable array was found, allocate a new one; either way, mark it busy.
x ≡ nothing && (x = f())
Base.@lock cache.lock push!(busy_pool, x)
return x
end

function free_busy!(cache::AllocCache)
for uid in cache.busy.keys
busy_pool = get_pool!(cache, :busy, uid)
isempty(busy_pool) && continue

Base.@lock cache.lock begin
free_pool = get_pool!(cache, :free, uid)
append!(free_pool, busy_pool)
empty!(busy_pool)
end
end
return
end

function unsafe_free!(cache::AllocCache)
Base.@lock cache.lock begin
for (_, pool) in cache.busy
isempty(pool) || error(
"Invalidating allocations cache that's currently in use. " *
"Invalidating inside `@cached` is not allowed."
)
end
for (_, pool) in cache.free
map(unsafe_free!, pool)
end
empty!(cache.free)
end
return
end

function Base.sizeof(cache::AllocCache)
sz = UInt64(0)
Base.@lock cache.lock begin
for kind in (cache.free, cache.busy), (_, pool) in kind
sz += sum(sizeof, pool; init = UInt64(0))
end
end
return sz
end

const ALLOC_CACHE = ScopedValue{Union{Nothing, AllocCache}}(nothing)

"""
@cached(cache, expr)

Evaluate expression `expr` using the allocation cache `cache`.

When a GPU allocation is requested during the execution of `expr`,
the cache is first checked for a matching "free" allocation instead of performing an actual allocation.
If no "free" allocation exists, an actual allocation is performed.
Before the allocation is returned to the user, it is marked as "busy" and
will not be handed out again within the scope defined by `@cached`.

**After** the execution of `expr`, all "busy" allocations are marked as "free" again,
so they can be reused the next time the program enters this scope.

This is useful in repeatedly executed blocks of code, to avoid relying on
the GC to free GPU memory in time.

# Example

In the following example, each iteration of the for loop requires `8 GiB` of GPU memory.
Without the caching allocator, the GC would not be able to free the arrays in time,
resulting in higher memory usage.
With the caching allocator, memory usage stays at exactly `8 GiB`.

```julia
cache = GPUArrays.AllocCache(CuArray)
n = 1024^3
for i in 1:1000
GPUArrays.@cached cache begin
sin.(CUDA.rand(Float32, n))
end
end
# Free the cached memory immediately.
# Otherwise, it will be freed when the cache itself is collected by the GC.
GPUArrays.unsafe_free!(cache)
```

See [`@uncached`](@ref).
"""
macro cached(cache, expr)
return quote
res = @with $(esc(ALLOC_CACHE)) => $(esc(cache)) $(esc(expr))
free_busy!($(esc(cache)))
res
end
end

"""
@uncached(expr)

Evaluate expression `expr` without using the allocation cache.
This is useful to call from within `@cached` to avoid caching some allocations.
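
For example (a sketch assuming CUDA.jl's `CuArray`; only `tmp` bypasses the cache):

```julia
cache = GPUArrays.AllocCache(CuArray)
GPUArrays.@cached cache begin
    y = CUDA.rand(Float32, 1024)                        # cached and reused
    GPUArrays.@uncached tmp = CUDA.rand(Float32, 1024)  # left to the GC as usual
end
```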
"""
macro uncached(expr)
return quote
@with $(esc(ALLOC_CACHE)) => nothing $(esc(expr))
end
end
1 change: 1 addition & 0 deletions test/testsuite.jl
@@ -93,6 +93,7 @@ include("testsuite/math.jl")
include("testsuite/random.jl")
include("testsuite/uniformscaling.jl")
include("testsuite/statistics.jl")
include("testsuite/alloc_cache.jl")

"""
Runs the entire GPUArrays test suite on array type `AT`
43 changes: 43 additions & 0 deletions test/testsuite/alloc_cache.jl
@@ -0,0 +1,43 @@
@testsuite "alloc cache" (AT, eltypes) -> begin
if AT <: AbstractGPUArray
cache = GPUArrays.AllocCache(AT)

T, dims = Float32, (1, 2, 3)
GPUArrays.@cached cache begin
x1 = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(T) * prod(dims)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test x1 === cache.free[key][1]

# Second allocation hits cache.
GPUArrays.@cached cache begin
x2 = AT(zeros(T, dims))
# Does not hit the cache.
GPUArrays.@uncached x_free = AT(zeros(T, dims))
end
@test sizeof(cache) == sizeof(T) * prod(dims)
key = first(keys(cache.free))
@test length(cache.free[key]) == 1
@test length(cache.busy[key]) == 0
@test x2 === cache.free[key][1]
@test x_free !== x2

# Third allocation is of different shape - allocates.
dims = (2, 2)
GPUArrays.@cached cache begin
x3 = AT(zeros(T, dims))
end
_keys = collect(keys(cache.free))
key2 = _keys[findfirst(i -> i != key, _keys)]
@test length(cache.free[key]) == 1
@test length(cache.free[key2]) == 1
@test x3 === cache.free[key2][1]

# Freeing all memory held by cache.
GPUArrays.unsafe_free!(cache)
@test sizeof(cache) == 0
end
end