Skip to content

Commit

Permalink
added some tests for too many and too few init params
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde authored Sep 13, 2023
1 parent 8808521 commit e2a04ba
Showing 1 changed file with 87 additions and 18 deletions.
105 changes: 87 additions & 18 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,18 @@
end

# initial parameters
init_params = [(b=randn(), a=rand()) for _ in 1:100]
nchains = 100
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -184,14 +185,36 @@
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)
end

@testset "Multicore sampling" begin
Expand Down Expand Up @@ -274,17 +297,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=randn(), b=rand()) for _ in 1:100]
nchains = 100
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -296,15 +320,37 @@
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThrMCMCDistributedeads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)

# Remove workers
rmprocs(pids...)
end
Expand Down Expand Up @@ -360,17 +406,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=rand(), b=randn()) for _ in 1:100]
nchains = 100
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -382,14 +429,36 @@
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)
end

@testset "Ensemble sampling: Reproducibility" begin
Expand Down

0 comments on commit e2a04ba

Please sign in to comment.