From 9960b5225c18a997711219c7b2a7040b9952aa7e Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 9 Jan 2025 11:47:43 +0100 Subject: [PATCH] Simplify back-end interface. --- lib/JLArrays/src/JLArrays.jl | 15 +++++---------- src/host/alloc_cache.jl | 11 ++++++++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 92ee7de0..18be1889 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -89,21 +89,15 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} check_eltype(T) maxsize = prod(dims) * sizeof(T) - function _alloc_f() + GPUArrays.cached_alloc((JLArray, T, dims)) do data = Vector{UInt8}(undef, maxsize) ref = DataRef(data) do data resize!(data, 0) end obj = new{T, N}(ref, 0, dims) - return finalizer(unsafe_free!, obj) - end - - cache = GPUArrays.ALLOC_CACHE[] - return if cache ≡ nothing - _alloc_f() - else - GPUArrays.alloc!(_alloc_f, cache, (JLArray, T, dims))::JLArray{T, N} - end + finalizer(unsafe_free!, obj) + return obj + end::JLArray{T,N} end # low-level constructor for wrapping existing data @@ -112,6 +106,7 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} check_eltype(T) obj = new{T,N}(ref, offset, dims) finalizer(unsafe_free!, obj) + return obj end end diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index 899f2e5d..6c9a1200 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -30,13 +30,18 @@ function get_pool!(cache::AllocCache{T}, pool::Symbol, uid::UInt64) where {T <: return uid_pool end -function alloc!(alloc_f, cache::AllocCache, key) +function cached_alloc(f, key) + cache = ALLOC_CACHE[] + if cache === nothing + return f() + end + x = nothing uid = hash(key) busy_pool = get_pool!(cache, :busy, uid) free_pool = get_pool!(cache, :free, uid) - isempty(free_pool) && (x = alloc_f()) + isempty(free_pool) && (x = f()) while !isempty(free_pool) && x ≡ nothing tmp = Base.@lock cache.lock pop!(free_pool) @@ -45,7 +50,7 @@ function alloc!(alloc_f, cache::AllocCache, key) x = tmp end - x ≡ nothing && (x = alloc_f()) + x ≡ nothing && (x = f()) Base.@lock cache.lock push!(busy_pool, x) return x end