
Commit

Merge pull request #19 from JuliaPOMDP/solver_options
Solver options
MaximeBouton authored Jul 23, 2018
2 parents 86310d8 + 4e3154e commit d4d56b8
Showing 6 changed files with 229 additions and 180 deletions.
1 change: 0 additions & 1 deletion src/DiscreteValueIteration.jl
@@ -11,7 +11,6 @@ import POMDPs: Solver, solve, Policy, action, value
export
ValueIterationPolicy,
ValueIterationSolver,
create_policy,
solve,
action,
value,
9 changes: 3 additions & 6 deletions src/docs.jl
@@ -3,6 +3,9 @@ The solver type. Contains the following parameters, all keyworded in the constructor:
- max_iterations::Int64, the maximum number of iterations value iteration runs for (default 100)
- belres::Float64, the Bellman residual (default 1e-3)
- verbose::Bool, if set to true, the Bellman residual and the time per iteration will be printed to STDOUT (default false)
- include_Q::Bool, if set to true, the solver outputs the Q values in addition to the utility and the policy (default true)
- init_util::Vector{Float64}, provides a custom initialization of the utility vector. (initializes utility to 0 by default)
The solver can be initialized by running:
`solver = ValueIterationSolver(max_iterations=1000, belres=1e-6)`
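With the options added in this commit, a fuller construction call might look like the sketch below (the `mdp` variable and all parameter values are illustrative, not defaults):

```julia
# illustrative only: seed the utility with zeros and skip the Q-matrix output
init_u = zeros(n_states(mdp))          # `mdp` is assumed to be defined elsewhere
solver = ValueIterationSolver(max_iterations=500,
                              belres=1e-6,
                              verbose=true,
                              include_Q=false,
                              init_util=init_u)
```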
@@ -24,12 +27,6 @@ The Q-matrix is nxm, where n is the number of states and m is the number of actions
ValueIterationPolicy


"""
Returns an empty policy
"""
create_policy


"""
Computes the optimal policy for an MDP. The function takes a verbose flag which can dump text output onto the screen.
You can run the function:
129 changes: 66 additions & 63 deletions src/vanilla.jl
@@ -2,10 +2,17 @@
mutable struct ValueIterationSolver <: Solver
max_iterations::Int64 # max number of iterations
belres::Float64 # the Bellman Residual
verbose::Bool
include_Q::Bool
init_util::Vector{Float64}
end
# Default constructor
function ValueIterationSolver(;max_iterations::Int64=100, belres::Float64=1e-3)
return ValueIterationSolver(max_iterations, belres)
function ValueIterationSolver(;max_iterations::Int64 = 100,
belres::Float64 = 1e-3,
verbose::Bool = false,
include_Q::Bool = true,
init_util::Vector{Float64}=Vector{Float64}(0))
return ValueIterationSolver(max_iterations, belres, verbose, include_Q, init_util)
end

# The policy type
@@ -16,60 +23,41 @@ mutable struct ValueIterationPolicy <: Policy
action_map::Vector # Maps the action index to the concrete action type
include_Q::Bool # Flag for including the Q-matrix
mdp::Union{MDP,POMDP} # uses the model for indexing in the action function
# constructor with an optional initial value function argument
function ValueIterationPolicy(mdp::Union{MDP,POMDP};
utility::Vector{Float64}=Array{Float64}(0),
include_Q::Bool=true)
ns = n_states(mdp)
na = n_actions(mdp)
self = new()
if !isempty(utility)
@assert first(size(utility)) == ns "Input utility dimension mismatch"
self.util = utility
else
self.util = zeros(ns)
end
self.action_map = ordered_actions(mdp)
self.policy = zeros(Int64,ns)
include_Q ? self.qmat = zeros(ns,na) : self.qmat = zeros(0,0)
self.include_Q = include_Q
self.mdp = mdp
return self
end
# constructor for solved q, util and policy
function ValueIterationPolicy(mdp::Union{MDP,POMDP}, q::Matrix{Float64}, util::Vector{Float64}, policy::Vector{Int64})
self = new()
self.qmat = q
self.util = util
self.policy = policy
self.action_map = ordered_actions(mdp)
self.include_Q = true
self.mdp = mdp
return self
end
# constructor for default Q-matrix
function ValueIterationPolicy(mdp::Union{MDP,POMDP}, q::Matrix{Float64})
(ns, na) = size(q)
p = zeros(ns)
u = zeros(ns)
for i = 1:ns
p[i] = indmax(q[i,:])
u[i] = maximum(q[i,:])
end
self = new()
self.qmat = q
self.util = u
self.policy = p
self.action_map = ordered_actions(mdp)
self.include_Q = true
self.mdp = mdp
return self
end
end

# returns a default value iteration policy
function create_policy(solver::ValueIterationSolver, mdp::Union{MDP,POMDP})
return ValueIterationPolicy(mdp)
# constructor with an optional initial value function argument
function ValueIterationPolicy(mdp::Union{MDP,POMDP};
utility::Vector{Float64}=zeros(n_states(mdp)),
policy::Vector{Int64}=zeros(Int64, n_states(mdp)),
include_Q::Bool=true)
ns = n_states(mdp)
na = n_actions(mdp)
@assert length(utility) == ns "Input utility dimension mismatch"
@assert length(policy) == ns "Input policy dimension mismatch"
action_map = ordered_actions(mdp)
include_Q ? qmat = zeros(ns,na) : qmat = zeros(0,0)
return ValueIterationPolicy(qmat, utility, policy, action_map, include_Q, mdp)
end

# constructor for solved q, util and policy
function ValueIterationPolicy(mdp::Union{MDP,POMDP}, q::Matrix{Float64}, util::Vector{Float64}, policy::Vector{Int64})
action_map = ordered_actions(mdp)
include_Q = true
return ValueIterationPolicy(q, util, policy, action_map, include_Q, mdp)
end

# constructor for default Q-matrix
function ValueIterationPolicy(mdp::Union{MDP,POMDP}, q::Matrix{Float64})
(ns, na) = size(q)
p = zeros(ns)
u = zeros(ns)
for i = 1:ns
p[i] = indmax(q[i,:])
u[i] = maximum(q[i,:])
end
action_map = ordered_actions(mdp)
include_Q = true
return ValueIterationPolicy(q, u, p, action_map, include_Q, mdp)
end
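For reference, the rewritten outer constructors can be exercised as in the sketch below, with a hypothetical two-state, two-action `mdp` and a made-up Q-matrix (purely illustrative):

```julia
# made-up 2x2 Q-matrix: rows are states, columns are actions
q = [1.0  0.5;
     0.2  2.0]
pol_from_q = ValueIterationPolicy(mdp, q)               # utility/policy extracted row-wise via maximum/indmax
empty_pol  = ValueIterationPolicy(mdp, include_Q=false) # zero utility and policy, no Q-matrix stored
```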

# returns the fields of the policy type
@@ -114,23 +102,34 @@ end
# policy = ValueIterationPolicy(mdp)
# solve(solver, mdp, policy, verbose=true)
#####################################################################
function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp); verbose::Bool=false)

function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)

# deprecation warning - can be removed when Julia 1.0 is adopted
if !isempty(kwargs)
warn("Keyword args for solve(::ValueIterationSolver, ::MDP) are no longer supported. For verbose output, use the verbose option in the ValueIterationSolver")
end

@warn_requirements solve(solver, mdp)

# solver parameters
max_iterations = solver.max_iterations
belres = solver.belres
discount_factor = discount(mdp)
ns = n_states(mdp)
na = n_actions(mdp)

# initialize the utility and Q-matrix
util = policy.util
qmat = policy.qmat
include_Q = policy.include_Q
if !isempty(solver.init_util)
@assert length(solver.init_util) == ns "Input utility dimension mismatch"
util = solver.init_util
else
util = zeros(ns)
end
include_Q = solver.include_Q
if include_Q
qmat[:] = 0.0
qmat = zeros(ns, na)
end
pol = policy.policy
pol = zeros(Int64, ns)

total_time = 0.0
iter_time = 0.0
@@ -178,10 +177,14 @@ function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp); verbose::Bool=false)
end # state
iter_time = toq()
total_time += iter_time
verbose ? @printf("[Iteration %-4d] residual: %10.3G | iteration runtime: %10.3f ms, (%10.3G s total)\n", i, residual, iter_time*1000.0, total_time) : nothing
solver.verbose ? @printf("[Iteration %-4d] residual: %10.3G | iteration runtime: %10.3f ms, (%10.3G s total)\n", i, residual, iter_time*1000.0, total_time) : nothing
residual < belres ? break : nothing
end # main
policy
if include_Q
return ValueIterationPolicy(mdp, qmat, util, pol)
else
return ValueIterationPolicy(mdp, utility=util, policy=pol, include_Q=false)
end
end
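Taken together with the solver changes above, the new entry point is simply `solve(solver, mdp)`; a minimal end-to-end sketch, assuming POMDPModels provides the GridWorld model as in the tests below:

```julia
using POMDPModels, DiscreteValueIteration

mdp = GridWorld(sx=2, sy=3, rs=[GridWorldState(2,3)], rv=[10.0])

# verbosity now lives on the solver; the old keyword form
# solve(solver, mdp, verbose=true) only emits a deprecation warning
solver = ValueIterationSolver(max_iterations=100, belres=1e-3, verbose=true)
policy = solve(solver, mdp)

a = action(policy, GridWorldState(1,1))   # best action at a state
v = value(policy, GridWorldState(1,1))    # its value under the policy
```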

function action(policy::ValueIterationPolicy, s::S) where S
110 changes: 74 additions & 36 deletions test/runtests_basic_value_iteration.jl
@@ -1,9 +1,8 @@

function support_serial_qtest(mdp::Union{MDP,POMDP}, file::AbstractString; niter::Int64=100, res::Float64=1e-3)
qt = readdlm(file)
solver = ValueIterationSolver(max_iterations=niter, belres=res)
policy = create_policy(solver, mdp)
policy = solve(solver, mdp, policy, verbose=true)
solver = ValueIterationSolver(max_iterations=niter, belres=res, verbose=true)
policy = solve(solver, mdp)
(q, u, p, am) = locals(policy)
npolicy = ValueIterationPolicy(mdp, deepcopy(q))
nnpolicy = ValueIterationPolicy(mdp, deepcopy(q), deepcopy(u), deepcopy(p))
@@ -17,43 +16,82 @@ end


function test_complex_gridworld()
# Load correct policy from file and verify we can reconstruct it
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = GridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
file = "grid-world-10x10-Q-matrix.txt"
niter = 100
res = 1e-3

return support_serial_qtest(mdp, file, niter=niter, res=res)
# Load correct policy from file and verify we can reconstruct it
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = GridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
file = "grid-world-10x10-Q-matrix.txt"
niter = 100
res = 1e-3

return support_serial_qtest(mdp, file, niter=niter, res=res)
end

function test_simple_grid()
# Simple test....
# GridWorld(sx=2,sy=3) w reward at (2,3):
# Here's our grid:
# |state (x,y)____available actions__|
# ----------------------------------------------
# |5 (1,3)__u,d,l,r__|6 (2,3)__u,d,l,r+REWARD__|
# |3 (1,2)__u,d,l,r__|4 (2,2)______u,d,l,r_____|
# |1 (1,1)__u,d,l,r__|2 (2,1)______u,d,l,r_____|
# ----------------------------------------------
# 7 (0,0) is absorbing state
mdp = GridWorld(sx=2, sy=3, rs = [GridWorldState(2,3)], rv = [10.0])

solver = ValueIterationSolver()
policy = create_policy(solver, mdp)
policy = solve(solver, mdp, policy, verbose=true)

# up: 1, down: 2, left: 3, right: 4
correct_policy = [1,1,1,1,4,1,1] # alternative policies
# are possible, but since they are tied & the first
# action is always 1, we will always return 1 for tied
# actions
return policy.policy == correct_policy
# Simple test....
# GridWorld(sx=2,sy=3) w reward at (2,3):
# Here's our grid:
# |state (x,y)____available actions__|
# ----------------------------------------------
# |5 (1,3)__u,d,l,r__|6 (2,3)__u,d,l,r+REWARD__|
# |3 (1,2)__u,d,l,r__|4 (2,2)______u,d,l,r_____|
# |1 (1,1)__u,d,l,r__|2 (2,1)______u,d,l,r_____|
# ----------------------------------------------
# 7 (0,0) is absorbing state
mdp = GridWorld(sx=2, sy=3, rs = [GridWorldState(2,3)], rv = [10.0])

solver = ValueIterationSolver(verbose=true)
policy = solve(solver, mdp)

# up: 1, down: 2, left: 3, right: 4
correct_policy = [1,1,1,1,4,1,1] # alternative policies
# are possible, but since they are tied & the first
# action is always 1, we will always return 1 for tied
# actions
return policy.policy == correct_policy
end

function test_init_solution()
# Initialize the value to the solution
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = GridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
qt = readdlm("grid-world-10x10-Q-matrix.txt")
ut = maximum(qt, 2)[:]
solver = ValueIterationSolver(verbose=true, init_util=ut, belres=1e-3)
policy = solve(solver, mdp)
return isapprox(ut, policy.util, rtol=1e-5)
end

function test_not_include_Q()
# Load correct policy from file and verify we can reconstruct it
rstates = [GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)]
rvals = [-10.0, -5.0, 10.0, 3.0]
xs = 10
ys = 10
mdp = GridWorld(sx=xs, sy=ys, rs = rstates, rv = rvals)
qt = readdlm("grid-world-10x10-Q-matrix.txt")
ut = maximum(qt, 2)[:]
niter = 100
res = 1e-3
solver = ValueIterationSolver(verbose=true, init_util=ut, belres=1e-3, include_Q=false)
policy = solve(solver, mdp)
return isapprox(ut, policy.util, rtol=1e-3)
end

function test_warning()
mdp = GridWorld()
solver = ValueIterationSolver()
println("There should be a warning bellow: ")
solve(solver, mdp, verbose=true)
end

@test test_complex_gridworld() == true
@test test_simple_grid() == true
@test test_init_solution() == true
@test test_not_include_Q() == true
test_warning()
