diff --git a/Project.toml b/Project.toml
index 8f72164d..c23df4eb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,14 +28,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 NaiveGAfluxCUDAExt = "CUDA"
 
 [compat]
-CUDA = "3, 4"
+CUDA = "3, 4, 5"
 Flux = "0.13.4, 0.14"
 Functors = "0.2, 0.3, 0.4"
 IterTools = "1"
 MemPool = "0.3"
 NaiveNASflux = "2.0.10"
 NaiveNASlib = "2.0.11"
-Optimisers = "0.2, 0.3"
+Optimisers = "0.3.2"
 PackageExtensionCompat = "1"
 PrecompileTools = "1"
 Reexport = "0.2.0, 1"
diff --git a/src/util.jl b/src/util.jl
index 4b6c2311..ec75146b 100644
--- a/src/util.jl
+++ b/src/util.jl
@@ -256,7 +256,7 @@ mergeopts(os::Tuple{}) = os
 mergeopts(::Type{T}, os::T...) where T = mergeopts(os...)
 mergeopts(os::Optimisers.AbstractRule...) = first(@set os[1].eta = prod(learningrate, os))
 mergeopts(os::ShieldedOpt{T}...) where T = ShieldedOpt(only(mergeopts(map(o -> o.rule, os))))
-mergeopts(os::WeightDecay...) = WeightDecay(prod(o -> o.gamma, os))
+mergeopts(os::WeightDecay...) = WeightDecay(prod(o -> o.lambda, os))
 
 """
     optmap(fopt, x, felse=identity)
diff --git a/test/Project.toml b/test/Project.toml
index f8a7a8ef..bdd3c63e 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -18,7 +18,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
-CUDA = "4"
+CUDA = "5"
 cuDNN = "1"
 Documenter = "0.27"
 Flux = "0.14"
diff --git a/test/runtests.jl b/test/runtests.jl
index b9a05deb..0879790d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -69,11 +69,11 @@ import CUDA
     @info "Testing visualization"
     include("visualization/callbacks.jl")
 
-    if VERSION === v"1.10.0"
+    if VERSION === v"1.10.4"
        @info "Testing README examples"
        include("examples.jl")
    else
-        @warn "README examples will only be tested in julia version 1.9.2 due to rng dependency. Skipping..."
+        @warn "README examples will only be tested in julia version 1.10.4 due to rng dependency. Skipping..."
    end
 
    @info "Testing AutoFlux"
diff --git a/test/util.jl b/test/util.jl
index 9826faa8..c6be1e5a 100644
--- a/test/util.jl
+++ b/test/util.jl
@@ -280,7 +280,7 @@ end
    @test learningrate(dm) ≈ 0.08f0
 
    wd = mergeopts(WeightDecay(0.1f0), WeightDecay(2f0), WeightDecay(0.4f0))
-    @test wd.gamma ≈ 0.08f0
+    @test wd.lambda ≈ 0.08f0
 
    dd = mergeopts(Momentum, Descent(0.1f0), Descent(2f0), Descent(0.4f0))
    @test typeof.(dd) == (Descent{Float32}, Descent{Float32}, Descent{Float32})
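
Reviewer context: the `gamma` → `lambda` renames in src/util.jl and test/util.jl track Optimisers.jl, which renamed the coefficient field of `WeightDecay`; that is also why the `[compat]` lower bound is pinned to exactly `Optimisers = "0.3.2"`. The sketch below illustrates the merge semantics that the updated `WeightDecay` method of `mergeopts` implements, with the product expression inlined from the diff; it assumes Optimisers.jl ≥ 0.3.2 and is an illustration, not code from this PR.

```julia
using Optimisers

# Assumes Optimisers.jl >= 0.3.2, where the WeightDecay coefficient is stored
# in the `lambda` field (earlier releases called it `gamma`).
rules = (WeightDecay(0.1f0), WeightDecay(2f0), WeightDecay(0.4f0))

# Inlined from the updated mergeopts method: several WeightDecay rules
# collapse into one whose coefficient is the product of the inputs.
merged = WeightDecay(prod(o -> o.lambda, rules))

# Matches the updated test expectation: 0.1f0 * 2f0 * 0.4f0 ≈ 0.08f0.
@assert merged.lambda ≈ 0.08f0
```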