Merge pull request #32 from JuliaPOMDP/ordered_dictionaries
Use OrderedDict
zsunberg authored Jul 8, 2024
2 parents 8545d8a + e79444c commit ed9a93d
Showing 8 changed files with 286 additions and 7 deletions.
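Despite the PR title, the merged diff shown here does not add an OrderedDict dependency: in branching_sim and expand!, each Dict is paired with a plain vector (olist) that records the order in which observations were first inserted, so observation branches are visited in a reproducible order instead of Base.Dict's hash-dependent order. A minimal sketch (not from the PR) of that pattern:

# Hypothetical example: pairing a Dict with an insertion-order list,
# the same pattern the PR applies to observation branches.
odict = Dict{Int,String}()
olist = Int[]
for k in (3, 1, 2)
    if !haskey(odict, k)
        odict[k] = "obs $k"
        push!(olist, k)          # remember first-insertion order
    end
end
for k in olist                   # deterministic: visits 3, 1, 2
    println(k, " => ", odict[k])
end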
134 changes: 134 additions & 0 deletions pr32_benchmarks/LD_disc_o.jl
@@ -0,0 +1,134 @@
#LightDark from POMDPModels.jl, modified for discrete binned observations

using Distributions
using POMDPs
using POMDPModels   # LightDark1DState, LDNormalStateDist
using POMDPTools    # Deterministic
using Random        # AbstractRNG
# A one-dimensional light-dark problem, originally used to test MCVI
# A very simple POMDP with continuous state and observation spaces.
# maintained by @zsunberg

"""
LightDark1D_DO
A one-dimensional light dark problem. The goal is to be near 0. Observations are noisy measurements of the position.
Model
-----
   -3-2-1 0 1 2 3
...| | | | | | | | ...
          G   S
Here G is the goal. S is the starting location.
"""
mutable struct LightDark1D_DO{F<:Function} <: POMDPs.POMDP{LightDark1DState,Int,Float64}
discount_factor::Float64
correct_r::Float64
incorrect_r::Float64
step_size::Float64
movement_cost::Float64
sigma::F
grid::Tuple{Float64, Float64}
grid_step::Float64
grid_len::Int
end

default_sigma(x::Float64) = abs(x - 5)/sqrt(2) + 1e-2

function LightDark1D_DO(;grid = collect(-10:20/40:10))
    return LightDark1D_DO(0.9, 10.0, -10.0, 1.0, 0.0, default_sigma,
                          (first(grid), last(grid)),
                          round(grid[2] - grid[1], digits=4),
                          length(grid))
end

POMDPs.discount(p::LightDark1D_DO) = p.discount_factor

POMDPs.isterminal(::LightDark1D_DO, act::Int64) = act == 0

POMDPs.isterminal(::LightDark1D_DO, s::LightDark1DState) = s.status < 0


POMDPs.actions(::LightDark1D_DO) = -1:1

POMDPs.initialstate(pomdp::LightDark1D_DO) = POMDPModels.LDNormalStateDist(2, 3)
POMDPs.initialobs(m::LightDark1D_DO, s) = observation(m, s)

struct DiscreteNormal
dist::Normal{Float64}
g_first::Float64
g_end::Float64
step::Float64
end

function DiscreteNormal(p::LightDark1D_DO,dist::Normal{Float64})
return DiscreteNormal(dist,p.grid[1],p.grid[2],p.grid_step)
end

function Base.rand(rng::AbstractRNG, dist::DiscreteNormal)
    val = rand(rng, dist.dist)
    # Snap the continuous sample up to the nearest grid point.
    return ceil((val - dist.g_first)/dist.step)*dist.step + dist.g_first
end

function Distributions.pdf(dist::DiscreteNormal, x::Float64)
    # x should lie on the observation grid. (The original assert took
    # ceil(...) % 1.0, which is always zero; this checks grid alignment.)
    @assert round((x - dist.g_first)/dist.step, digits=8) % 1.0 == 0.0
    if x <= dist.g_first
        # Lowest bin absorbs the left tail.
        return cdf(dist.dist, x)
    elseif x >= dist.g_end
        # Highest bin absorbs the right tail.
        return 1.0 - cdf(dist.dist, x - dist.step)
    else
        # Interior bin: mass between the two neighboring grid points.
        return cdf(dist.dist, x) - cdf(dist.dist, x - dist.step)
    end
end

function POMDPs.observation(p::LightDark1D_DO, sp::LightDark1DState)
return DiscreteNormal(p,Normal(sp.y, p.sigma(sp.y)))
end

# function POMDPs.observation(p::LightDark1D_DO, sp::LightDark1DState)
# dist = Normal(sp.y, p.sigma(sp.y))
# o_vals = zeros(p.grid_len)
# old_cdf = 0.0
# grid = collect(p.grid[1]:p.grid_step:p.grid[2])
# for (i,g) in enumerate(grid)
# if i == 1
# old_cdf = cdf(dist,g)
# o_vals[i] = old_cdf
# elseif i == p.grid_len
# o_vals[end] = 1.0-old_cdf
# else
# new_cdf = cdf(dist,g)
# o_vals[i] = new_cdf-old_cdf
# old_cdf = new_cdf
# end
# end
# # @assert all(o_vals .>= 0)
# # @assert abs(sum(o_vals)-1.0) < 0.0001
# return SparseCat(grid,o_vals)
# end

function POMDPs.transition(p::LightDark1D_DO, s::LightDark1DState, a::Int)
if a == 0
return Deterministic(LightDark1DState(-1, s.y+a*p.step_size))
else
return Deterministic(LightDark1DState(s.status, s.y+a*p.step_size))
end
end

function POMDPs.reward(p::LightDark1D_DO, s::LightDark1DState, a::Int)
if s.status < 0
return 0.0
elseif a == 0
if abs(s.y) < 1
return p.correct_r
else
return p.incorrect_r
end
else
return -p.movement_cost*a
end
end


POMDPs.convert_s(::Type{A}, s::LightDark1DState, p::LightDark1D_DO) where A<:AbstractArray = eltype(A)[s.status, s.y]
POMDPs.convert_s(::Type{LightDark1DState}, s::A, p::LightDark1D_DO) where A<:AbstractArray = LightDark1DState(Int64(s[1]), s[2])
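For reference, a minimal usage sketch (not part of the commit) showing how the discrete-observation model is driven, assuming the file above is saved as pr32_benchmarks/LD_disc_o.jl. The tail bins absorb the leftover probability mass, so the binned pdf sums to one over the grid:

include("LD_disc_o.jl")
pomdp = LightDark1D_DO()                 # default 41-point grid on [-10, 10]
s = LightDark1DState(0, 2.0)             # status 0, position y = 2.0
d = POMDPs.observation(pomdp, s)         # DiscreteNormal over the grid
o = rand(MersenneTwister(1), d)          # a sample snapped to a grid point
total = sum(Distributions.pdf(d, g) for g in collect(-10:20/40:10))
@assert abs(total - 1.0) < 1e-8          # tail bins absorb the leftover mass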
8 changes: 8 additions & 0 deletions pr32_benchmarks/Project.toml
@@ -0,0 +1,8 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
67 changes: 67 additions & 0 deletions pr32_benchmarks/benchmarks.jl
@@ -0,0 +1,67 @@
#Benchmarks for ARDESPOT.jl PR No. 32 (Use OrderedDict)

#Note: uses HistoryRecorder instead of calling action directly, which may add timing noise.

using Pkg
Pkg.activate(".")
using ARDESPOT

Pkg.activate("pr32_benchmarks")
using BenchmarkTools

using POMDPs
using POMDPModels
using Random
using POMDPTools
using ParticleFilters


##Baby Tests
pomdp = BabyPOMDP()
pomdp.discount = 1.0
bds = IndependentBounds(DefaultPolicyLB(FeedWhenCrying()), 0.0)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("BabyPOMDP ===================================")
display(@benchmark simulate(hr, pomdp, planner))
println("")

##Tiger Tests
pomdp = TigerPOMDP()

solver = DESPOTSolver(bounds=(-20.0, 0.0),T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)

hr = HistoryRecorder(max_steps=3)
println("Tiger ===================================")
display(@benchmark simulate(hr, pomdp, planner))
println("")

##LightDark POMDP
include("LD_disc_o.jl")
pomdp = LightDark1D_DO(;grid = collect(-10:20/40:10))

lb_pol = FunctionPolicy(b->-1)
bds = IndependentBounds(DefaultPolicyLB(lb_pol), pomdp.correct_r,check_terminal=true)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("LightDark D.O. - 40 Obs ===================================")
display(@benchmark simulate(hr, pomdp, planner))
println("")

##LD 2
pomdp = LightDark1D_DO(;grid = collect(-10:20/100:10))

lb_pol = FunctionPolicy(b->-1)
bds = IndependentBounds(DefaultPolicyLB(lb_pol), pomdp.correct_r,check_terminal=true)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("LightDark D.O. - 100 Obs ===================================")
display(@benchmark simulate(hr, pomdp, planner))
println("")

#Add RockSample???
"done"
63 changes: 63 additions & 0 deletions pr32_benchmarks/ld_policies.jl
@@ -0,0 +1,63 @@
#LD Policies for Bounds
#Note: assumes MCTS.jl (DPWSolver/DPWPlanner), POMDPTools (UnderlyingMDP),
#and ParticleFilters (weighted_particles) are loaded by the including script.
struct LDPolicy{M<:POMDP,S}<:Policy
m::M
particle_dict::Dict{S,Float64}
act::Int64
planner::DPWPlanner
end

function LDPolicy(m::LightDark1D)
    solver = DPWSolver(n_iterations=50, depth=20, exploration_constant=5.0)
    planner = solve(solver, UnderlyingMDP(m))
    # Key the dictionary by the state type; the original's LightDark1D{S}
    # bound S to the model's sigma-function type parameter, not the state type.
    return LDPolicy(m, Dict{statetype(m),Float64}(), -1, planner)
end

##Tyler's map_bel:
@inline function get_incr!(h::Dict{K,V}, key::K, v) where {K,V} # modified from dict.jl
    index = Base.ht_keyindex2!(h, key)

    v = convert(V, v)
    if index > 0
        # Key already present: accumulate the weight. (The original added a
        # constant 1 here, which counts particles instead of summing weights.)
        h.age += 1
        return h.vals[index] += v
    end

    # New key. (The original's age-recheck block was dead code: age0 was read
    # immediately before the comparison, so it could never re-probe.)
    @inbounds Base._setindex!(h, v, key, -index)
    return v
end

function map_bel(b::AbstractParticleBelief, pol)
    empty!(pol.particle_dict)
    dict = pol.particle_dict
    max_w = 0.0
    # Seed with an arbitrary particle; the original used pol.m.state1, but
    # LightDark1D has no state1 field, so that line would error.
    max_state = first(particles(b))
    for (p, w) in weighted_particles(b)
        n = get_incr!(dict, p, w)
        if n > max_w
            max_w = n
            max_state = p
        end
    end
    return max_state
end

function POMDPs.action(policy::LDPolicy, b::Union{ParticleCollection,ScenarioBelief})
    max_p = map_bel(b, policy)
    # Plan on the underlying MDP from the MAP state; the original recursed
    # into POMDPs.action(policy, max_p), for which no method taking a raw
    # state is defined.
    return POMDPs.action(policy.planner, max_p)
end

function POMDPs.action(policy::LDPolicy, b::POMDPModels.LDNormalStateDist)
    max_p = map_bel(b, policy)
    return POMDPs.action(policy.planner, max_p)
end
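The get_incr!/map_bel pair above is an optimized MAP-particle lookup built on Dict internals. A plain-Julia sketch of the same idea (map_bel_simple is a hypothetical name, not code from the commit):

function map_bel_simple(b)
    # Accumulate total weight per unique particle, then return the
    # heaviest one (the maximum a posteriori particle).
    w = Dict{Any,Float64}()
    for (p, wt) in weighted_particles(b)
        w[p] = get(w, p, 0.0) + wt
    end
    return argmax(w)   # key with the maximum accumulated weight
end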
6 changes: 4 additions & 2 deletions src/default_policy_sim.jl
@@ -2,6 +2,7 @@ function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::I
     S = statetype(pomdp)
     O = obstype(pomdp)
     odict = Dict{O, Vector{Pair{Int, S}}}()
+    olist = O[]
 
     if steps <= 0
         return length(b.scenarios)*fval(pomdp, b)
@@ -19,14 +20,15 @@ function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::I
                 push!(odict[o], k=>sp)
             else
                 odict[o] = [k=>sp]
+                push!(olist,o)
             end
 
             r_sum += r
         end
     end
 
     next_r = 0.0
-    for (o, scenarios) in odict
+    for o in olist
+        scenarios = odict[o]
         bp = ScenarioBelief(scenarios, b.random_source, b.depth+1, o)
         if length(scenarios) == 1
             next_r += rollout(pomdp, policy, bp, steps-1, fval)
4 changes: 2 additions & 2 deletions src/planner.jl
@@ -2,10 +2,10 @@ function build_despot(p::DESPOTPlanner, b_0)
     D = DESPOT(p, b_0)
     b = 1
     trial = 1
-    start = CPUtime_us()
+    start = time()
 
     while D.mu[1]-D.l[1] > p.sol.epsilon_0 &&
-        CPUtime_us()-start < p.sol.T_max*1e6 &&
+        time()-start < p.sol.T_max &&
         trial <= p.sol.max_trials
         b = explore!(D, 1, p)
         backup!(D, b, p)
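The change above swaps CPUtime_us() (CPU time in microseconds, from CPUTime.jl) for Base.time() (wall-clock seconds), so T_max is compared in seconds without the 1e6 scaling. A minimal sketch of the new timing pattern, with a hypothetical budget:

let T_max = 0.01             # hypothetical budget in wall-clock seconds
    start = time()
    trials = 0
    while time() - start < T_max
        trials += 1          # stand-in for explore!/backup!
    end
    trials
end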
9 changes: 7 additions & 2 deletions src/tree.jl
@@ -57,10 +57,12 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)
     A = actiontype(p.pomdp)
     O = obstype(p.pomdp)
     odict = Dict{O, Int}()
+    olist = O[]
 
     belief = get_belief(D, b, p.rs)
     for a in actions(p.pomdp, belief)
         empty!(odict)
+        empty!(olist)
         rsum = 0.0
 
         for scen in D.scenarios[b]
@@ -74,12 +76,13 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)
                     push!(D.scenarios, Vector{Pair{Int, S}}())
                     bp = length(D.scenarios)
                     odict[o] = bp
+                    push!(olist,o)
                 end
                 push!(D.scenarios[bp], first(scen)=>sp)
             end
         end
 
-        push!(D.ba_children, collect(values(odict)))
+        push!(D.ba_children, [odict[o] for o in olist])
         ba = length(D.ba_children)
         push!(D.ba_action, a)
         push!(D.children[b], ba)
@@ -89,7 +92,9 @@ function expand!(D::DESPOT, b::Int, p::DESPOTPlanner)
 
         nbps = length(odict)
         resize!(D, length(D.children) + nbps)
-        for (o, bp) in odict
+        for o in olist
+            bp = odict[o]
+
             D.obs[bp] = o
             D.children[bp] = Int[]
             D.parent_b[bp] = b
2 changes: 1 addition & 1 deletion test/baby_sanity_check.jl
@@ -1,6 +1,6 @@
 using POMDPs
 using ARDESPOT
-using POMDPToolbox
+using POMDPTools
 using POMDPModels
 using ProgressMeter
 
