Skip to content

Commit

Permalink
Update test/Project.toml
Browse files Browse the repository at this point in the history
Change WeightDecay.gamma => lambda and put compat bound on Optimisers
  • Loading branch information
DrChainsaw committed Jun 29, 2024
1 parent b6361b2 commit 2471c5d
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ IterTools = "1"
MemPool = "0.3"
NaiveNASflux = "2.0.10"
NaiveNASlib = "2.0.11"
Optimisers = "0.2, 0.3"
Optimisers = "0.3.2"
PackageExtensionCompat = "1"
PrecompileTools = "1"
Reexport = "0.2.0, 1"
Expand Down
2 changes: 1 addition & 1 deletion src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ mergeopts(os::Tuple{}) = os
mergeopts(::Type{T}, os::T...) where T = mergeopts(os...)
mergeopts(os::Optimisers.AbstractRule...) = first(@set os[1].eta = prod(learningrate, os))
mergeopts(os::ShieldedOpt{T}...) where T = ShieldedOpt(only(mergeopts(map(o -> o.rule, os))))
mergeopts(os::WeightDecay...) = WeightDecay(prod(o -> o.gamma, os))
mergeopts(os::WeightDecay...) = WeightDecay(prod(o -> o.lambda, os))

"""
optmap(fopt, x, felse=identity)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
CUDA = "4"
CUDA = "5"
cuDNN = "1"
Documenter = "0.27"
Flux = "0.14"
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ import CUDA
@info "Testing visualization"
include("visualization/callbacks.jl")

if VERSION === v"1.10.0"
if VERSION === v"1.10.4"
@info "Testing README examples"
include("examples.jl")
else
@warn "README examples will only be tested in julia version 1.9.2 due to rng dependency. Skipping..."
@warn "README examples will only be tested in julia version 1.10.4 due to rng dependency. Skipping..."
end

@info "Testing AutoFlux"
Expand Down
2 changes: 1 addition & 1 deletion test/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ end
@test learningrate(dm) 0.08f0

wd = mergeopts(WeightDecay(0.1f0), WeightDecay(2f0), WeightDecay(0.4f0))
@test wd.gamma 0.08f0
@test wd.lambda 0.08f0

dd = mergeopts(Momentum, Descent(0.1f0), Descent(2f0), Descent(0.4f0))
@test typeof.(dd) == (Descent{Float32}, Descent{Float32}, Descent{Float32})
Expand Down

0 comments on commit 2471c5d

Please sign in to comment.