-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from JuliaPOMDP/ordered_dictionaries
Use OrderedDict
- Loading branch information
Showing
8 changed files
with
286 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#LightDark from POMDPModels.jl, modified for discrete binned observations | ||
|
||
using Distributions | ||
# A one-dimensional light-dark problem, originally used to test MCVI | ||
# A very simple POMDP with a continuous state space; in this variant the observations are binned onto a discrete grid.
# maintained by @zsunberg | ||
|
||
""" | ||
LightDark1D_DO | ||
A one-dimensional light dark problem. The goal is to be near 0. Observations are noisy measurements of the position. | ||
Model | ||
----- | ||
   -3-2-1 0 1 2 3
...| | | | | | | | ...
          G   S
Here G is the goal. S is the starting location | ||
""" | ||
mutable struct LightDark1D_DO{F <: Function} <: POMDPs.POMDP{LightDark1DState, Int, Float64}
    # Value returned by `POMDPs.discount`.
    discount_factor::Float64
    # Reward for taking action 0 while |y| < 1.
    correct_r::Float64
    # Reward for taking action 0 anywhere else.
    incorrect_r::Float64
    # Distance moved per unit of action in `POMDPs.transition`.
    step_size::Float64
    # Per-step cost factor used by `POMDPs.reward` for non-zero actions.
    movement_cost::Float64
    # Callable mapping a position y to the observation noise standard deviation.
    sigma::F
    # (first, last) points of the observation grid.
    grid::Tuple{Float64, Float64}
    # Spacing between neighbouring grid points.
    grid_step::Float64
    # Number of grid points.
    grid_len::Int
end
|
||
# Default observation noise: grows linearly with the distance of `x` from 5,
# with a floor of 1e-2 so the standard deviation is never zero.
function default_sigma(x::Float64)
    distance = abs(x - 5)
    return distance / sqrt(2) + 1e-2
end
|
||
# Keyword constructor: builds the problem with fixed default parameters and an
# observation grid described by its end points, its (rounded) spacing and its length.
function LightDark1D_DO(; grid = collect(-10:20/40:10))
    bounds = (first(grid), last(grid))
    # Spacing is taken from the first two grid points and rounded to 4 digits.
    spacing = round(grid[2] - grid[1], digits=4)
    n_points = length(grid)
    return LightDark1D_DO(0.9, 10.0, -10.0, 1.0, 0.0, default_sigma, bounds, spacing, n_points)
end
|
||
# Discount factor of the problem, read straight from the model.
function POMDPs.discount(p::LightDark1D_DO)
    return p.discount_factor
end
|
||
# Integer overload: true exactly when the integer (an action) is 0.
# NOTE(review): the model's first type parameter is LightDark1DState, so an
# `isterminal` method on an Int is unusual -- confirm which caller relies on it.
function POMDPs.isterminal(::LightDark1D_DO, act::Int64)
    return act == 0
end
|
||
# A state is terminal once its status flag is negative
# (`POMDPs.transition` sets it to -1 after action 0).
function POMDPs.isterminal(::LightDark1D_DO, s::LightDark1DState)
    return s.status < 0
end
|
||
|
||
# Action space: the integers -1, 0 and 1.
function POMDPs.actions(::LightDark1D_DO)
    return -1:1
end
|
||
# Initial state distribution, borrowed from POMDPModels.
# NOTE(review): the arguments (2, 3) are presumably mean and spread -- confirm against POMDPModels.
function POMDPs.initialstate(pomdp::LightDark1D_DO)
    return POMDPModels.LDNormalStateDist(2, 3)
end
# The initial observation distribution is the ordinary observation distribution at `s`.
function POMDPs.initialobs(m::LightDark1D_DO, s)
    return observation(m, s)
end
|
||
# A normal distribution whose samples are snapped onto a regular grid.
struct DiscreteNormal
    dist::Normal{Float64}   # underlying continuous distribution
    g_first::Float64        # first grid point
    g_end::Float64          # last grid point
    step::Float64           # spacing between grid points
end
|
||
# Convenience constructor: take the grid end points and spacing from the problem `p`.
function DiscreteNormal(p::LightDark1D_DO, dist::Normal{Float64})
    lo, hi = p.grid
    return DiscreteNormal(dist, lo, hi, p.grid_step)
end
|
||
# Draw one observation from `dist`.
#
# A sample of the underlying normal is rounded *up* to the next grid point
# `g_first + k*step`, then clamped to `[g_first, g_end]`. The clamp keeps the
# sampler consistent with `pdf`, which folds all mass below `g_first` into the
# first grid point and all mass above `g_end - step` into the last one; without
# it, samples could land outside the grid, where `pdf` is not a normalised pmf.
function Base.rand(rng::AbstractRNG, dist::DiscreteNormal)
    raw = rand(rng, dist.dist)
    snapped = ceil((raw - dist.g_first) / dist.step) * dist.step + dist.g_first
    return clamp(snapped, dist.g_first, dist.g_end)
end
|
||
# Probability mass that `dist` assigns to the grid point `x`.
#
# * `x <= g_first`: all mass of the underlying normal up to `x`.
# * `x >= g_end`:   all mass above `x - step`.
# * otherwise:      mass of the bin `(x - step, x]`.
#
# `x` must lie on the lattice `g_first + k*step` (up to round-off). The previous
# check rounded to an integer before testing for a fractional part, so it could
# never fail for finite input; this one really verifies the offset is integral.
function Distributions.pdf(dist::DiscreteNormal, x::Float64)
    bins = (x - dist.g_first) / dist.step
    @assert isapprox(bins, round(bins); atol=1e-8) "x = $x is not on the observation grid"
    if x <= dist.g_first
        return cdf(dist.dist, x)
    elseif x >= dist.g_end
        return 1.0 - cdf(dist.dist, x - dist.step)
    else
        return cdf(dist.dist, x) - cdf(dist.dist, x - dist.step)
    end
end
|
||
# Observation distribution at state `sp`: a grid-discretised normal centred on
# the position `sp.y`, with standard deviation given by the model's `sigma`.
function POMDPs.observation(p::LightDark1D_DO, sp::LightDark1DState)
    noise = Normal(sp.y, p.sigma(sp.y))
    return DiscreteNormal(p, noise)
end
|
||
# function POMDPs.observation(p::LightDark1D_DO, sp::LightDark1DState) | ||
# dist = Normal(sp.y, p.sigma(sp.y)) | ||
# o_vals = zeros(p.grid_len) | ||
# old_cdf = 0.0 | ||
# grid = collect(p.grid[1]:p.grid_step:p.grid[2]) | ||
# for (i,g) in enumerate(grid) | ||
# if i == 1 | ||
# old_cdf = cdf(dist,g) | ||
# o_vals[i] = old_cdf | ||
# elseif i == p.grid_len | ||
# o_vals[end] = 1.0-old_cdf | ||
# else | ||
# new_cdf = cdf(dist,g) | ||
# o_vals[i] = new_cdf-old_cdf | ||
# old_cdf = new_cdf | ||
# end | ||
# end | ||
# # @assert all(o_vals .>= 0) | ||
# # @assert abs(sum(o_vals)-1.0) < 0.0001 | ||
# return SparseCat(grid,o_vals) | ||
# end | ||
|
||
# Deterministic dynamics: the position moves by `a * step_size`. Action 0 also
# flags the state as finished (status -1); any other action keeps the status.
function POMDPs.transition(p::LightDark1D_DO, s::LightDark1DState, a::Int)
    next_status = a == 0 ? -1 : s.status
    next_y = s.y + a*p.step_size
    return Deterministic(LightDark1DState(next_status, next_y))
end
|
||
# Reward for taking action `a` in state `s`.
#
# * Finished states (negative status) yield 0.
# * Action 0 yields `correct_r` when |y| < 1 and `incorrect_r` otherwise.
# * Any other action costs `movement_cost` per unit moved. The magnitude of the
#   action is used so that moving in the negative direction is charged as well;
#   multiplying by the signed action turned the cost into a bonus for `a = -1`.
function POMDPs.reward(p::LightDark1D_DO, s::LightDark1DState, a::Int)
    if s.status < 0
        return 0.0
    elseif a == 0
        return abs(s.y) < 1 ? p.correct_r : p.incorrect_r
    else
        return -p.movement_cost * abs(a)
    end
end
|
||
|
||
# State <-> array conversions. The methods are attached to `POMDPs.convert_s`
# explicitly (like every other interface method in this file); an unqualified
# definition under `using POMDPs` does not extend the POMDPs function.

# Pack a state into a two-element array `[status, y]` with the element type of `A`.
function POMDPs.convert_s(::Type{A}, s::LightDark1DState, p::LightDark1D_DO) where A<:AbstractArray
    return eltype(A)[s.status, s.y]
end

# Rebuild a state from an array whose first two entries are the status and the position.
function POMDPs.convert_s(::Type{LightDark1DState}, s::A, p::LightDark1D_DO) where A<:AbstractArray
    return LightDark1DState(Int64(s[1]), s[2])
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[deps] | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" | ||
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" | ||
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" | ||
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#Benchmarks for ARDESPOT.jl PR No. 32 (Use OrderedDict) | ||
|
||
#Note: Uses HistoryRecorder instead of calls to action directly. This may add additional noise. | ||
|
||
using Pkg | ||
Pkg.activate(".") | ||
using ARDESPOT | ||
|
||
Pkg.activate("pr32_benchmarks") | ||
using BenchmarkTools | ||
|
||
using POMDPs | ||
using POMDPModels | ||
using Random | ||
using POMDPTools | ||
using ParticleFilters | ||
|
||
|
||
##Baby Tests
pomdp = BabyPOMDP()
pomdp.discount = 1.0
bds = IndependentBounds(DefaultPolicyLB(FeedWhenCrying()), 0.0)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("BabyPOMDP ===================================")
# Interpolate the globals with `$` so the measurement does not include the
# overhead of accessing untyped global variables.
display(@benchmark simulate($hr, $pomdp, $planner))
println("")
|
||
##Tiger Tests
pomdp = TigerPOMDP()

solver = DESPOTSolver(bounds=(-20.0, 0.0),T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)

hr = HistoryRecorder(max_steps=3)
println("Tiger ===================================")
# Interpolate the globals with `$` so the measurement does not include the
# overhead of accessing untyped global variables.
display(@benchmark simulate($hr, $pomdp, $planner))
println("")
|
||
##LightDark POMDP
include("LD_disc_o.jl")
pomdp = LightDark1D_DO(;grid = collect(-10:20/40:10))

# Lower bound: roll out a policy that always takes action -1.
lb_pol = FunctionPolicy(b->-1)
bds = IndependentBounds(DefaultPolicyLB(lb_pol), pomdp.correct_r,check_terminal=true)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("LightDark D.O. - 40 Obs ===================================")
# Interpolate the globals with `$` so the measurement does not include the
# overhead of accessing untyped global variables.
display(@benchmark simulate($hr, $pomdp, $planner))
println("")
|
||
##LD 2
pomdp = LightDark1D_DO(;grid = collect(-10:20/100:10))

# Lower bound: roll out a policy that always takes action -1.
lb_pol = FunctionPolicy(b->-1)
bds = IndependentBounds(DefaultPolicyLB(lb_pol), pomdp.correct_r,check_terminal=true)
solver = DESPOTSolver(bounds=bds,T_max=Inf,max_trials=1000,rng=MersenneTwister(5))
planner = solve(solver, pomdp)
hr = HistoryRecorder(max_steps=2)
println("LightDark D.O. - 100 Obs ===================================")
# Interpolate the globals with `$` so the measurement does not include the
# overhead of accessing untyped global variables.
display(@benchmark simulate($hr, $pomdp, $planner))
println("")
|
||
#Add RockSample??? | ||
"done" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#LD Policies for Bounds | ||
# Policy wrapper around a light-dark model, used when computing bounds.
struct LDPolicy{M <: POMDP, S} <: Policy
    m::M                              # problem model
    particle_dict::Dict{S, Float64}   # scratch table, emptied and refilled by `map_bel`
    act::Int64                        # the constructor initialises this to -1
    planner::DPWPlanner               # planner obtained by solving the underlying MDP
end
|
||
# Build an LDPolicy for `m`: solve the underlying MDP with a DPW solver and
# start with an empty particle table and `act = -1`.
function LDPolicy(m::LightDark1D{S}) where S
    mdp_solver = DPWSolver(n_iterations=50, depth=20, exploration_constant=5.0)
    scratch = Dict{S,Float64}()
    return LDPolicy(m, scratch, -1, solve(mdp_solver, UnderlyingMDP(m)))
end
|
||
##Tyler's map_bel: | ||
# Add `v` to the value stored under `key` in `h` (treating a missing key as
# zero) and return the updated value.
#
# Implemented with the public `Dict` API only. The earlier version poked at
# private hash-table internals of Base, which differ between Julia releases,
# carried branches that could never run, and inserted `v` for a new key but
# added 1 for an existing one. Accumulating `v` in both cases gives the same
# result for unit weights and a proper weight total otherwise.
@inline function get_incr!(h::Dict{K,V}, key::K, v) where {K,V}
    total = get(h, key, zero(V)) + convert(V, v)
    h[key] = total
    return total
end
|
||
# Return the particle of `b` with the largest running tally in the policy's
# scratch table. The table is cleared first; if no tally exceeds zero, the
# model's `state1` is returned.
function map_bel(b::AbstractParticleBelief, pol)
    tally = pol.particle_dict
    empty!(tally)
    # best_state = first(particles(b))
    best_state = pol.m.state1
    best_score = 0.0
    for (particle, wt) in weighted_particles(b)
        score = get_incr!(tally, particle, wt)
        if score > best_score
            best_score, best_state = score, particle
        end
    end
    return best_state
end
|
||
# Action for a particle-based belief: reduce the belief to its most heavily
# tallied particle and delegate to the policy's action for that state.
function POMDPs.action(policy::LDPolicy, s::Union{ParticleCollection,ScenarioBelief})
    return POMDPs.action(policy, map_bel(s, policy))
end
|
||
# Action for the initial state distribution: same reduction as for particle beliefs.
# NOTE(review): `map_bel` is only visibly defined for `AbstractParticleBelief`;
# confirm a method accepting `LDNormalStateDist` exists elsewhere.
function POMDPs.action(policy::LDPolicy, s::POMDPModels.LDNormalStateDist)
    return POMDPs.action(policy, map_bel(s, policy))
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
using POMDPs | ||
using ARDESPOT | ||
using POMDPToolbox | ||
using POMDPTools | ||
using POMDPModels | ||
using ProgressMeter | ||
|
||
|