diff --git a/docs/src/sim.md b/docs/src/sim.md index 328d570..e6fe966 100644 --- a/docs/src/sim.md +++ b/docs/src/sim.md @@ -11,7 +11,9 @@ end ``` This allows a flexible and general way to interact with a POMDP environment without creating new `Policy` types. -Note: by default, since there is no observation before the first action, on the first call to the `do` block, `obs` is `nothing`. +In the POMDP case, an updater can optionally be supplied as an additional positional argument if the policy function works with beliefs rather than directly with observations. + +More examples can be found in the [POMDPExamples Package](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Running-Simulations.ipynb) More examples can be found in the [POMDPExamples Package](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Running-Simulations.ipynb) diff --git a/src/sim.jl b/src/sim.jl index 2288fd5..8bb3416 100644 --- a/src/sim.jl +++ b/src/sim.jl @@ -2,8 +2,8 @@ # maintained by @zsunberg """ - sim(polfunc::Function, mdp::MDP[, initial_state]; [kwargs...]) - sim(polfunc::Function, pomdp::POMDP[, initial_state]) + sim(polfunc::Function, mdp::MDP; [<keyword arguments>]) + sim(polfunc::Function, pomdp::POMDP; [<keyword arguments>]) Alternative way of running a simulation with a function specifying how to calculate the action at each timestep. @@ -18,44 +18,54 @@ Alternative way of running a simulation with a function specifying how to calcul for an MDP or sim(pomdp) do o - # code that does belief updates with observation `o` and calculates `a` - # you can also do other things like display something + # code that calculates `a` based on observation `o` + # optionally you could save `o` in a global variable or do a belief update + return a + end + +for a POMDP, or + + sim(pomdp, updater) do b + # code that calculates `a` based on belief `b` + # `b` is calculated by `updater` return a end -for a POMDP. +for a POMDP and a belief updater. 
# Keyword Arguments -Use the `simulator` keyword argument to specify any simulator to run the simulation. If nothing is specified for the simulator, a HistoryRecorder will be used as the simulator, with all keyword arguments forwarded to it, e.g. +## All Versions - sim(mdp, max_steps=100, show_progress=true) do s - # ... - end +- `initialstate`: the initial state for the simulation +- `simulator`: keyword argument to specify any simulator to run the simulation. If nothing is specified for the simulator, a HistoryRecorder will be used as the simulator, with all keyword arguments forwarded to it, e.g. + ``` + sim(mdp, max_steps=100, show_progress=true) do s + # ... + end + ``` + will limit the simulation to 100 steps. + +## POMDP version -will limit the simulation to 100 steps. +- `initialobs`: this will control the initial observation given to the policy function. If this is not defined, `generate_o(m, s, rng)` will be used if it is available. If it is not, `missing` will be used. -The POMDP version also has two additional keyword arguments: -- `initialobs`: this will control the initial observation given to the policy function. -- `updater`: if provided, this updater will be used to update the belief, and the belief will be used as the argument to the policy function. If a custom updater is provided, the `initialobs` keyword argument should be used to specify the initial belief. +## POMDP and updater version + +- `initialbelief`: `initialize_belief(updater, initialbelief)` is the first belief that will be given to the policy function. """ function sim end -function sim(polfunc::Function, mdp::MDP, - initialstate=nothing; +function sim(polfunc::Function, mdp::MDP; + initialstate=nothing, simulator=nothing, kwargs... 
) kwargd = Dict(kwargs) if initialstate==nothing && statetype(mdp) != Nothing - if haskey(kwargd, :initialstate) - initialstate = pop!(kwargd, :initialstate) - else - initialstate = default_init_state(mdp) - end + initialstate = default_init_state(mdp) end - delete!(kwargd, :initialstate) if simulator==nothing simulator = HistoryRecorder(;kwargd...) end @@ -63,36 +73,46 @@ function sim(polfunc::Function, mdp::MDP, simulate(simulator, mdp, policy, initialstate) end -function sim(polfunc::Function, pomdp::POMDP, - initialstate=nothing; +function sim(polfunc::Function, pomdp::POMDP; + initialstate=nothing, simulator=nothing, initialobs=nothing, - updater=nothing, kwargs... ) kwargd = Dict(kwargs) if initialstate==nothing && statetype(pomdp) != Nothing - if haskey(kwargd, :initialstate) - initialstate = pop!(kwargd, :initialstate) - else - initialstate = default_init_state(pomdp) - end + initialstate = default_init_state(pomdp) end - delete!(kwargd, :initialstate) if simulator==nothing simulator = HistoryRecorder(;kwargd...) end + updater = PreviousObservationUpdater() if initialobs==nothing && obstype(pomdp) != Nothing initialobs = default_init_obs(pomdp, initialstate) end - if updater==nothing - updater = PreviousObservationUpdater() - end policy = FunctionPolicy(polfunc) simulate(simulator, pomdp, policy, updater, initialobs, initialstate) end +function sim(polfunc::Function, pomdp::POMDP, updater::Updater; + initialstate=nothing, + simulator=nothing, + initialbelief=initialstate_distribution(pomdp), + kwargs... + ) + + kwargd = Dict(kwargs) + if initialstate==nothing && statetype(pomdp) != Nothing + initialstate = default_init_state(pomdp) + end + if simulator==nothing + simulator = HistoryRecorder(;kwargd...) 
+ end + policy = FunctionPolicy(polfunc) + simulate(simulator, pomdp, policy, updater, initialbelief, initialstate) +end + function default_init_obs(p::POMDP, s) if implemented(generate_o, Tuple{typeof(p), typeof(s), typeof(Random.GLOBAL_RNG)}) return generate_o(p, s, Random.GLOBAL_RNG) @@ -109,9 +129,11 @@ end error(""" Error in sim(::$(typeof(p))): No initial state specified. - Please supply it as an argument after the mdp or define the method POMDPs.initialstate(::$(typeof(p)), ::$(typeof(Random.GLOBAL_RNG))) or define the method POMDPs.initialstate_distribution(::$(typeof(p))). + Please supply it as a keyword argument or define the method POMDPs.initialstate(::$(typeof(p)), ::$(typeof(Random.GLOBAL_RNG))) or define the method POMDPs.initialstate_distribution(::$(typeof(p))). """) end end end + +@deprecate sim(f::Function, m::Union{POMDP, MDP}, initialstate; kwargs...) sim(f, m; initialstate=initialstate, kwargs...) diff --git a/test/test_sim.jl b/test/test_sim.jl index cc72791..631a065 100644 --- a/test/test_sim.jl +++ b/test/test_sim.jl @@ -12,7 +12,7 @@ end @testset "BabyPOMDP sim" begin pomdp = BabyPOMDP() - hist = sim(pomdp, max_steps=100) do obs + hist = sim(pomdp, max_steps=100, initialobs=false) do obs @assert isa(obs, Bool) acts = actions(pomdp) return rand(acts) @@ -20,16 +20,21 @@ end @test length(hist) == 100 hist = sim(pomdp, false, max_steps=100) do obs - @assert isa(obs, Bool) acts = actions(pomdp) return rand(acts) end @test length(hist) == 100 hist = sim(pomdp, initialstate=true, max_steps=100) do obs - @assert isa(obs, Bool) acts = actions(pomdp) return rand(acts) end @test length(hist) == 100 + + hist = sim(pomdp, max_steps=100, DiscreteUpdater(pomdp)) do b + @assert isa(b, DiscreteBelief) + acts = actions(pomdp) + return rand(acts) + end + end