From 54cdd12c89b3847bf385137ce51df5f672eccac8 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Wed, 22 Jun 2022 22:14:17 -0700 Subject: [PATCH] Improve type stability of LayerNorm and Dropout --- src/Flux.jl | 4 ++- src/layers/normalise.jl | 58 +++++++++++++++++++++++++----------- src/layers/stateless.jl | 38 +++++++++++++++++++---- test/layers/normalisation.jl | 13 +++++++- 4 files changed, 88 insertions(+), 25 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 0cacbd419a..f9599569c9 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,7 +9,9 @@ using MacroTools: @forward using MLUtils import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions -using Zygote, ChainRulesCore +using ChainRulesCore + +using Zygote using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 4c696d916d..258131f8e7 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -2,7 +2,9 @@ istraining() = false ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),) -_isactive(m) = isnothing(m.active) ? istraining() : m.active +_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active) + +ChainRulesCore.@non_differentiable _isactive(::Any) _dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) @@ -31,18 +33,43 @@ automatically managed using the [`Dropout`](@ref) layer instead of the The [`Dropout`](@ref) layer is what you should use in most scenarios. """ -function dropout(rng, x, p; dims=:, active::Bool=true) - active || return x - y = dropout_mask(rng, x, p, dims=dims) - return x .* y -end +dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active) dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...) -dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) -dropout_mask(rng, x::CuArray, p; kwargs...) = - throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays.")) -dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...) -function _dropout_mask(rng, x, p; dims=:) +# Internal function without kwargs to keep Zygote generated code type stable +function _dropout(rng, x, p, dims, active) + mask = active ? dropout_mask(rng, x, p, dims) : nothing + return _apply_mask(x, mask) +end + +function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active) + mask = active ? dropout_mask(rng, x, p, dims) : nothing + # Required because we don't always call dropout_mask + MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)}) + project_x = ProjectTo(x) + return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x) +end + +# Also needed for type stability. Otherwise inference lifts the Union into a +# Union{pullback{Nothing}, pullback{AbstractArray}} +struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}} + mask::Union{Nothing,M} + project::P +end + +function (pb::DropoutPullback)(dy) + dx = pb.project(_apply_mask(dy, pb.mask)) + return (NoTangent(), NoTangent(), dx, NoTangent()) +end + +_apply_mask(x, ::Nothing) = x +_apply_mask(x, mask) = x .* mask + +dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims) +dropout_mask(rng, x::CuArray, p, dims) = + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays.")) +dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims) +function _dropout_mask(rng, x, p, dims) realfptype = float(real(eltype(x))) y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) y .= _dropout_kernel.(y, p, 1 - p) @@ -50,7 +77,7 @@ function _dropout_mask(rng, x, p; dims=:) end # TODO move this to NNlib -ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any) +ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any) """ Dropout(p; dims=:, rng = rng_from_array()) @@ -82,10 +109,7 @@ end @functor Dropout trainable(a::Dropout) = (;) -function (a::Dropout)(x) - _isactive(a) || return x - return dropout(a.rng, x, a.p; dims=a.dims, active=true) -end +(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a)) testmode!(m::Dropout, mode=true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) @@ -172,7 +196,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end] @functor LayerNorm -(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ)) +(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ)) function Base.show(io::IO, l::LayerNorm) print(io, "LayerNorm(", join(l.size, ", ")) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 34e365ae9d..d7f8758394 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -26,15 +26,41 @@ function flatten(x::AbstractArray) return reshape(x, :, size(x)[end]) end +# Utils for LayerNorm internals. +# Most of these are required for better performance and type stability under AD. +# In an ideal world, we'd just have normalise. + +function _mean_std(x::AbstractArray, dims) + μ = mean(x, dims=dims) + σ = std(x, dims=dims, mean=μ, corrected=false) + return μ, σ +end + +function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims) + μ, mean_pullback = ChainRulesCore.rrule(mean, x, dims=dims) + σ, std_pullback = ChainRulesCore.rrule(std, x, dims=dims, mean=μ, corrected=false) + function _mean_std_pullback((dμ, dσ)) + dx = ChainRulesCore.add!!(std_pullback(dσ)[2], mean_pullback(dμ)[2]) + return (NoTangent(), dx, NoTangent()) + end + + return (μ, σ), _mean_std_pullback +end + +_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ) + +# We don't define a rrule for the whole function because we want +# AD to figure out the _zscore broadcast for us. +function _normalize(x::AbstractArray, dims, ϵ) + μ, σ = _mean_std(x, dims) + return _zscore.(x, μ, σ, ϵ) +end + """ normalise(x; dims=ndims(x), ϵ=1e-5) Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`. -Per default, `dims` is the last dimension. +Per default, `dims` is the last dimension. `ϵ` is a small additive factor added to the denominator for numerical stability. """ -@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) - μ = mean(x, dims=dims) - σ = std(x, dims=dims, mean=μ, corrected=false) - return @. (x - μ) / (σ + ϵ) -end +@inline normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) = _normalize(x, dims, ϵ) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 7ae15aeff9..cda9bab972 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -73,6 +73,12 @@ evalwgrad(f, x...) = pullback(f, x...)[1] @test cpu(m).rng === only(values(rng_kwargs)) end end + + for active in (true, false) + m = Dropout(0.5, :, active) + @inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}} + @inferred back(ones(10)) # Array{Float64} + end end @testset "AlphaDropout" begin @@ -343,8 +349,13 @@ end @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) x = rand(2,3,4,5) @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2) - @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) + m = LayerNorm(2, tanh) + @test m(x) ≈ tanh.(Flux.normalise(x, dims=1)) + @inferred _, back = pullback(sum∘m, x) + @inferred back(1.0) + x = rand(2,3,4,5) @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))