From 9bc217cfe409ae4dbf7fcf72b56483f507059537 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 17 Sep 2024 23:17:09 -0700 Subject: [PATCH] Make recursive_add/accumulate more recursive --- src/compiler.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1d21fb99a1..cf8a5406bf 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6562,7 +6562,13 @@ end Base.@_inline_meta prev = getfield(x, i) next = getfield(y, i) - recursive_add(prev, next, f, forcelhs) + ST = Core.Typeof(prev) + if !mutable_register(ST) + recursive_add(prev, next, f, forcelhs) + elseif !(ST <: Integer) + recursive_accumulate(prev, next, f) + prev + end end) end @@ -6591,10 +6597,11 @@ end # Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) @inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} - if !mutable_register(T) - for I in eachindex(x) - prev = x[I] + for I in eachindex(x, y) + if !mutable_register(T) @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) + elseif !(T <: Integer) + recursive_accumulate((@inbounds x[I]), (@inbounds y[I]), f) end end end @@ -6602,7 +6609,7 @@ end # Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) @inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F} - recursive_accumulate(x.contents, y.contents, seen, f) + recursive_accumulate(x.contents, y.contents, f) end @inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F} @@ -6613,12 +6620,14 @@ end for i in 1:nf if isdefined(x, i) xi = getfield(x, i) + yi = getfield(y, i) ST = Core.Typeof(xi) if !mutable_register(ST) @assert ismutable(x) - yi = getfield(y, i) nexti = recursive_add(xi, yi, f, mutable_register) setfield!(x, i, nexti) + elseif !(ST <: Integer) + recursive_accumulate(xi, yi, f) end end end