Skip to content

Commit

Permalink
fix: eachslice adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 23, 2025
1 parent a98d587 commit a101ee7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
6 changes: 3 additions & 3 deletions lib/LuxLib/src/impl/bias_activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!
∇bias_activation_no_intermediate = @closure Δ -> begin
∂x = CRC.ProjectTo(x)(∇activation(recursive_unthunk(Δ), x, σ, NotaNumber()))
∂b = CRC.@thunk CRC.ProjectTo(bias)(∇bias_add(bias, ∂x))
return ∂∅, ∂∅, ∂∅, ∂x, ∂b
return ∂∅, ∂∅, ∂∅, ∂∅, ∂x, ∂b
end
return x, ∇bias_activation_no_intermediate
end
Expand All @@ -145,7 +145,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!
∇bias_activation_rrule = @closure Δ -> begin
∂x = CRC.ProjectTo(x)(∇activation(recursive_unthunk(Δ), y, σ, tmp))
∂b = CRC.@thunk CRC.ProjectTo(bias)(∇bias_add(bias, ∂x))
return ∂∅, ∂∅, ∂∅, ∂x, ∂b
return ∂∅, ∂∅, ∂∅, ∂∅, ∂x, ∂b
end
return y, ∇bias_activation_rrule
end
Expand All @@ -154,7 +154,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!
cfg, bias_activation, opmode, σ, x, bias)
∇bias_activation_fallback = @closure Δ -> begin
_, _, _, ∂x, ∂b = ∇bias_activation_from_ad(Δ)
return ∂∅, ∂∅, ∂∅, CRC.ProjectTo(x)(∂x), CRC.ProjectTo(bias)(∂b)
return ∂∅, ∂∅, ∂∅, ∂∅, CRC.ProjectTo(x)(∂x), CRC.ProjectTo(bias)(∂b)
end
return res, ∇bias_activation_fallback
end
Expand Down
19 changes: 10 additions & 9 deletions src/extended_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,17 @@ function eachslice(::Type{<:AbstractDevice}, x::AbstractArray, ::Val{dims}) wher
return [selectdim(x, dims, i) for i in axes(x, dims)]
end

function ∇eachslice(Δ, x::AbstractArray, ::Val{dims}) where {dims}
idx = findfirst(Base.Fix2(isa, AbstractArray), Δ)
idx === nothing && return zero.(x)
Δ = similar(x)
fill!(Δ, false)
for i in axes(x, dims)
Δᵢ = selectdim(Δ, dims, i)
copyto!(Δᵢ, Δ[i])
function ∇eachslice(Δs, x::AbstractArray, ::Val{dims}) where {dims}
idx = findfirst(Base.Fix2(isa, AbstractArray), Δs)
idx === nothing && return CRC.ZeroTangent()
return CRC.@thunk begin
Δ = similar(x)
fill!(Δ, false)
for i in axes(x, dims)
copyto!(selectdim(Δ, dims, i), Δs[i])
end
return CRC.ProjectTo(x)(Δ)
end
return CRC.ProjectTo(x)(Δ)
end

function CRC.rrule(::typeof(eachslice), x::AbstractArray, d::Val{dims}) where {dims}
Expand Down

0 comments on commit a101ee7

Please sign in to comment.