fix and add tests
ToucheSir committed Jun 24, 2022
1 parent 54cdd12 commit 9259e4a
Showing 2 changed files with 13 additions and 8 deletions.
11 changes: 8 additions & 3 deletions src/layers/normalise.jl
@@ -2,9 +2,14 @@ istraining() = false

ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
_isactive(m) = Bool(something(m.active, istraining()))

ChainRulesCore.@non_differentiable _isactive(::Any)
# Avoids instabilities from differentiating through getproperty(m, :active)
function ChainRulesCore.rrule(::typeof(_isactive), m)
training, _ = rrule(istraining)
_isactive_pullback(_) = (NoTangent(), NoTangent())
return Bool(something(m.active, training)), _isactive_pullback
end

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
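
For context on the `_isactive` change above (not part of the diff): `something(x, y)` returns `x` unless it is `nothing`, so the one-line rewrite is behaviourally identical to the old ternary, and the hand-written rrule replaces `@non_differentiable` so that Zygote never differentiates through `getproperty(m, :active)`. A minimal sketch of the equivalence, using only Base:

# Sketch only — the identity behind the rewrite above:
# Bool(something(m.active, istraining())) ==
#     (isnothing(m.active) ? istraining() : Bool(m.active))
something(nothing, false)   # false — falls back to the istraining() default
something(true, false)      # true  — an explicit m.active wins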
@@ -59,7 +64,7 @@ end

function (pb::DropoutPullback)(dy)
dx = pb.project(_apply_mask(dy, pb.mask))
return (NoTangent(), NoTangent(), dx, NoTangent())
return (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent(), NoTangent())
end

_apply_mask(x, ::Nothing) = x
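
The two extra `NoTangent()`s in the DropoutPullback return above follow the ChainRules convention that a pullback returns one entry per primal argument (plus one for the function itself), presumably because the dropout method this rule differentiates now takes two more positional arguments. A toy illustration of that convention, unrelated to Flux's actual signatures:

using ChainRulesCore

# Toy rule: scale(x, a, dims) takes three arguments, so its pullback returns
# four entries — one for the function, then one per argument, in order.
scale(x, a, dims) = a .* x
function ChainRulesCore.rrule(::typeof(scale), x, a, dims)
    scale_pullback(dy) = (NoTangent(), a .* dy, sum(dy .* x), NoTangent())
    return scale(x, a, dims), scale_pullback
end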
10 changes: 5 additions & 5 deletions test/layers/normalisation.jl
@@ -73,10 +73,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
@test cpu(m).rng === only(values(rng_kwargs))
end
end

for active in (true, false)
m = Dropout(0.5, :, active)
@inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
_, back = @inferred pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
@inferred back(ones(10)) # Array{Float64}
end
end
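
The `@inferred` reordering above matters because the macro is meant to wrap the call expression whose return type is checked; placed in front of the destructuring assignment it does not test the `pullback` call as intended. A standalone sketch of the distinction, independent of Flux:

using Test

stable(x)   = x > 0 ? 1.0 : 2.0   # always returns Float64
unstable(x) = x > 0 ? 1.0 : 1     # Float64 or Int depending on the branch
@inferred stable(1.0)             # passes: inferred and runtime types agree
# @inferred unstable(1.0)         # would throw, since inference gives Union{Float64, Int64}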
@@ -353,9 +353,9 @@ end
x = rand(2)
m = LayerNorm(2, tanh)
@test m(x) ≈ tanh.(Flux.normalise(x, dims=1))
@inferred _, back = pullback(sum∘m, x)
@inferred back(1.0)

_, back = @inferred pullback(|>, x, m)
# TODO needs https://github.com/FluxML/Zygote.jl/pull/1248
# @inferred back(1.0)

x = rand(2,3,4,5)
@test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))
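
As a usage note (a sketch, assuming Flux and Zygote are loaded): `x |> m` is simply `m(x)`, so `pullback(|>, x, m)` differentiates the layer application with both the input and the layer as explicit arguments; inference through `back` itself still waits on the Zygote PR referenced in the TODO above.

using Flux, Zygote

m = LayerNorm(2, tanh)
x = rand(Float32, 2)
y, back = Zygote.pullback(|>, x, m)   # y == m(x)
dx, dm = back(ones(Float32, 2))       # tangents for x and for the layer m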
