Skip to content

Commit

Permalink
Try #1837:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Feb 5, 2022
2 parents cce7ad0 + 7e4480b commit e1278a9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ end
@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
W, b, σ = a.weight, a.bias, a.σ
W, b= a.weight, a.bias
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ.(W*x .+ b)
end

Expand Down
12 changes: 8 additions & 4 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ end
@functor Conv

function (c::Conv)(x::AbstractArray)
σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
b = reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
σ.(conv(x, c.weight, cdims) .+ b)
end
Expand Down Expand Up @@ -278,7 +279,8 @@ end
@nograd conv_transpose_dims

function (c::ConvTranspose)(x::AbstractArray)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ b)
end
Expand Down Expand Up @@ -371,7 +373,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])

function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
end
Expand Down Expand Up @@ -450,7 +453,8 @@ function crosscor(x, w, ddims::DenseConvDims)
end

function (c::CrossCor)(x::AbstractArray)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, cdims) .+ b)
end
Expand Down
11 changes: 6 additions & 5 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
Wi, Wh, b = m.Wi, m.Wh, m.b
σ = NNlib.fast_act(m.σ, x)
h = σ.(Wi*x .+ Wh*h .+ b)
return h, reshape_cell_output(h, x)
end
Expand Down Expand Up @@ -224,8 +225,8 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr
b, o = m.b, size(h, 1)
g = m.Wi*x .+ m.Wh*h .+ b
input, forget, cell, output = multigate(g, o, Val(4))
c′ = @. σ(forget) * c + σ(input) * tanh(cell)
h′ = @. σ(output) * tanh(c′)
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return (h′, c′), reshape_cell_output(h′, x)
end

Expand Down Expand Up @@ -309,7 +310,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
= @. tanh(gxs[3] + r * ghs[3] + bs[3])
= @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) *+ z * h
return h′, reshape_cell_output(h′, x)
end
Expand Down Expand Up @@ -387,7 +388,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T}
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
r, z = _gru_output(gxs, ghs, bs)
= tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
= tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) *+ z * h
return h′, reshape_cell_output(h′, x)
end
Expand Down

0 comments on commit e1278a9

Please sign in to comment.