Improve type stability of LayerNorm and Dropout
ToucheSir committed Jun 24, 2022
1 parent 952c4a5 commit 54cdd12
Showing 4 changed files with 88 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/Flux.jl
@@ -9,7 +9,9 @@ using MacroTools: @forward
using MLUtils
import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions

using Zygote, ChainRulesCore
using ChainRulesCore

using Zygote
using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

58 changes: 41 additions & 17 deletions src/layers/normalise.jl
@@ -2,7 +2,9 @@ istraining() = false

ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

_isactive(m) = isnothing(m.active) ? istraining() : m.active
_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)

ChainRulesCore.@non_differentiable _isactive(::Any)

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
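
A minimal sketch (not part of this commit) of why converting with `Bool(...)` and marking `_isactive` as `@non_differentiable` helps inference. `Toggle`, `loose`, and `strict` below are hypothetical stand-ins for a layer with an `active::Union{Nothing,Bool}` field:

using Test: @inferred

struct Toggle
    active::Union{Nothing,Bool}
end

# Each `t.active` load only infers as Union{Nothing,Bool}, so the ternary's
# return type stays a small Union rather than a concrete Bool.
loose(t) = isnothing(t.active) ? false : t.active

# Bool(...) pins the return type to Bool, mirroring the new `_isactive` above.
strict(t) = isnothing(t.active) ? false : Bool(t.active)

@inferred strict(Toggle(true))  # concrete Bool; the same check on `loose` would likely fail
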
@@ -31,26 +33,51 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
The [`Dropout`](@ref) layer is what you should use in most scenarios.
"""
function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x
y = dropout_mask(rng, x, p, dims=dims)
return x .* y
end
dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
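
For reference, a usage sketch of the functional form documented above (assumes Flux with this commit; the choice of `Random.default_rng()` is only illustrative):

using Flux, Random

x = ones(Float32, 4, 3)
# Roughly half the entries are zeroed; kept entries are scaled by 1/(1 - p).
y = Flux.dropout(Random.default_rng(), x, 0.5; dims=:, active=true)
# With active=false the input is returned unchanged.
Flux.dropout(x, 0.5; active=false) == x  # true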

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
# Internal function without kwargs to keep Zygote generated code type stable
function _dropout(rng, x, p, dims, active)
mask = active ? dropout_mask(rng, x, p, dims) : nothing
return _apply_mask(x, mask)
end

function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
mask = active ? dropout_mask(rng, x, p, dims) : nothing
# Required because we don't always call dropout_mask
MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
project_x = ProjectTo(x)
return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
end

# Also needed for type stability. Otherwise inference lifts the Union into a
# Union{pullback{Nothing}, pullback{AbstractArray}}
struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
mask::Union{Nothing,M}
project::P
end

function (pb::DropoutPullback)(dy)
dx = pb.project(_apply_mask(dy, pb.mask))
return (NoTangent(), NoTangent(), dx, NoTangent())
end

_apply_mask(x, ::Nothing) = x
_apply_mask(x, mask) = x .* mask
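
A minimal sketch (hypothetical types, not Flux code) of the pattern used here: building the pullback as a closure over `mask::Union{Nothing,M}` would create one closure type per branch, so the rrule's return type becomes a Union of two pullback types; a single callable struct keeps the pullback type concrete and confines the small Union to a field:

# Closure version: two distinct anonymous-function types, one per value of `mask`.
make_pullback_closure(mask) = dy -> mask === nothing ? dy : dy .* mask

# Struct version, analogous to DropoutPullback: one type, with the Union only in the field.
struct ToyPullback{M<:AbstractArray}
    mask::Union{Nothing,M}
end
(pb::ToyPullback)(dy) = pb.mask === nothing ? dy : dy .* pb.mask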

dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
dropout_mask(rng, x::CuArray, p, dims) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
function _dropout_mask(rng, x, p, dims)
realfptype = float(real(eltype(x)))
y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
y .= _dropout_kernel.(y, p, 1 - p)
return y
end

# TODO move this to NNlib
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)

"""
Dropout(p; dims=:, rng = rng_from_array())
@@ -82,10 +109,7 @@ end
@functor Dropout
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
end
(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))

testmode!(m::Dropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
@@ -172,7 +196,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]

@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))

function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm(", join(l.size, ", "))
38 changes: 32 additions & 6 deletions src/layers/stateless.jl
@@ -26,15 +26,41 @@ function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
end

# Utils for LayerNorm internals.
# Most of these are required for better performance and type stability under AD.
# In an ideal world, we'd just have normalise.

function _mean_std(x::AbstractArray, dims)
μ = mean(x, dims=dims)
σ = std(x, dims=dims, mean=μ, corrected=false)
return μ, σ
end

function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims)
μ, mean_pullback = ChainRulesCore.rrule(mean, x, dims=dims)
σ, std_pullback = ChainRulesCore.rrule(std, x, dims=dims, mean=μ, corrected=false)
function _mean_std_pullback((dμ, dσ))
dx = ChainRulesCore.add!!(std_pullback(dσ)[2], mean_pullback(dμ)[2])
return (NoTangent(), dx, NoTangent())
end

return (μ, σ), _mean_std_pullback
end
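
An illustrative sanity check (a sketch only: `Flux._mean_std` is an internal helper introduced by this commit, and `fused`/`naive` are names invented here). Because ∂σ/∂μ vanishes at μ = mean(x), accumulating the `mean` and `std` pullbacks as above should agree with differentiating the naive composition:

using Flux, Zygote, Statistics

x = rand(Float32, 2, 3)
fused(x) = (μσ = Flux._mean_std(x, 1); sum(μσ[1]) + sum(μσ[2]))
naive(x) = sum(mean(x, dims=1)) + sum(std(x, dims=1, corrected=false))
Zygote.gradient(fused, x)[1] ≈ Zygote.gradient(naive, x)[1]  # expected to hold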

_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ)

# We don't define a rrule for the whole function because we want
# AD to figure out the _zscore broadcast for us.
function _normalize(x::AbstractArray, dims, ϵ)
μ, σ = _mean_std(x, dims)
return _zscore.(x, μ, σ, ϵ)
end

"""
normalise(x; dims=ndims(x), ϵ=1e-5)
Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
Per default, `dims` is the last dimension.
`ϵ` is a small additive factor added to the denominator for numerical stability.
"""
@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
μ = mean(x, dims=dims)
σ = std(x, dims=dims, mean=μ, corrected=false)
return @. (x - μ) / (σ + ϵ)
end
@inline normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) = _normalize(x, dims, ϵ)
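
A short usage sketch (assumes Flux is loaded) showing that the rewritten `normalise` still computes the usual z-score with the default ϵ:

using Flux, Statistics

x = rand(Float32, 3, 4)
μ = mean(x, dims=2)
σ = std(x, dims=2, mean=μ, corrected=false)
Flux.normalise(x; dims=2) ≈ (x .- μ) ./ (σ .+ 1f-5)  # true; dims defaults to ndims(x)
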
13 changes: 12 additions & 1 deletion test/layers/normalisation.jl
@@ -73,6 +73,12 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
@test cpu(m).rng === only(values(rng_kwargs))
end
end

for active in (true, false)
m = Dropout(0.5, :, active)
@inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
@inferred back(ones(10)) # Array{Float64}
end
end

@testset "AlphaDropout" begin
@@ -343,8 +349,13 @@ end
@test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)
x = rand(2,3,4,5)
@test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)

x = rand(2)
@test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1))
m = LayerNorm(2, tanh)
@test m(x) ≈ tanh.(Flux.normalise(x, dims=1))
@inferred _, back = pullback(sum∘m, x)
@inferred back(1.0)


x = rand(2,3,4,5)
@test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))
