Skip to content

Commit

Permalink
fix: eachslice adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 23, 2025
1 parent a98d587 commit 563d8c3
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/extended_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,17 @@ function eachslice(::Type{<:AbstractDevice}, x::AbstractArray, ::Val{dims}) wher
return [selectdim(x, dims, i) for i in axes(x, dims)]
end

function ∇eachslice(Δ, x::AbstractArray, ::Val{dims}) where {dims}
idx = findfirst(Base.Fix2(isa, AbstractArray), Δ)
idx === nothing && return zero.(x)
Δ = similar(x)
fill!(Δ, false)
for i in axes(x, dims)
Δᵢ = selectdim(Δ, dims, i)
copyto!(Δᵢ, Δ[i])
function ∇eachslice(Δs, x::AbstractArray, ::Val{dims}) where {dims}
idx = findfirst(Base.Fix2(isa, AbstractArray), Δs)
idx === nothing && return CRC.ZeroTangent()
return CRC.@thunk begin
Δ = similar(x)
fill!(Δ, false)
for i in axes(x, dims)
copyto!(selectdim(Δ, dims, i), Δs[i])
end
return CRC.ProjectTo(x)(Δ)
end
return CRC.ProjectTo(x)(Δ)
end

function CRC.rrule(::typeof(eachslice), x::AbstractArray, d::Val{dims}) where {dims}
Expand Down

0 comments on commit 563d8c3

Please sign in to comment.