sparse.jl
struct SparseValueIterationSolver <: Solver
    max_iterations::Int64      # maximum number of Bellman backup sweeps
    belres::Float64            # Bellman residual threshold used as the convergence criterion
    include_Q::Bool            # whether to store the Q matrix in the returned policy
    verbose::Bool              # print residual and timing information each iteration
    init_util::Vector{Float64} # optional initial value function; empty means start from zeros
end

function SparseValueIterationSolver(;max_iterations=500,
                                    belres::Float64=1e-3,
                                    include_Q::Bool=true,
                                    verbose::Bool=false,
                                    init_util::Vector{Float64}=Vector{Float64}(undef, 0))
    return SparseValueIterationSolver(max_iterations, belres, include_Q, verbose, init_util)
end
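# Usage sketch (illustrative only, not executed as part of this file): the solver is
# built with the keyword constructor above and passed to `solve` together with an MDP
# that satisfies the requirements listed below. `SimpleGridWorld` from POMDPModels.jl
# is assumed here purely as an example problem.
#
#   using POMDPs, POMDPModels, DiscreteValueIteration
#
#   solver = SparseValueIterationSolver(max_iterations=1000, belres=1e-6, verbose=true)
#   policy = solve(solver, SimpleGridWorld())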
@POMDP_require solve(solver::SparseValueIterationSolver, mdp::MDP) begin
    P = typeof(mdp)
    S = statetype(P)
    A = actiontype(P)
    @req discount(::P)
    @subreq ordered_states(mdp)
    @subreq ordered_actions(mdp)
    @req transition(::P,::S,::A)
    @req reward(::P,::S,::A,::S)
    @req stateindex(::P,::S)
    @req actionindex(::P, ::A)
    @req actions(::P, ::S)
    as = actions(mdp)
    ss = states(mdp)
    @req length(::typeof(ss))
    @req length(::typeof(as))
    a = first(as)
    s = first(ss)
    dist = transition(mdp, s, a)
    D = typeof(dist)
    @req support(::D)
    @req pdf(::D,::S)
    @subreq SparseTabularMDP(mdp)
end
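# Both `qvalue!` methods below compute the standard Bellman Q backup,
#
#     Q(s, a) = R(s, a) + γ * Σ_{s'} T(s' | s, a) * V(s'),
#
# in vectorized form: column `a` of `out_qvals_S_A` is
# `reward_S_A[:, a] + γ * transition_A_S_S2[a] * value_S`, where `transition_A_S_S2[a]`
# is the |S|×|S| sparse transition matrix for action `a`.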
# Allocating variant: builds a temporary vector for each action.
function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMatrix{F}, value_S::AbstractVector{F}, out_qvals_S_A) where {F}
    @assert size(out_qvals_S_A) == (length(states(m)), length(actions(m)))
    for a in 1:length(actions(m))
        out_qvals_S_A[:, a] = view(reward_S_A, :, a) + discount(m) * transition_A_S_S2[a] * value_S
    end
end

# In-place variant: reuses `_mul_cache` for the matrix-vector product to avoid
# allocating a new |S|-vector for every action on every iteration.
function qvalue!(m::MDP, transition_A_S_S2, reward_S_A, value_S, out_qvals_S_A, _mul_cache)
    @assert size(out_qvals_S_A) == (length(states(m)), length(actions(m)))
    γ = discount(m)
    for a in 1:length(actions(m))
        Vp = mul!(_mul_cache, transition_A_S_S2[a], value_S)
        out_qvals_S_A[:, a] .= view(reward_S_A, :, a) .+ γ .* Vp
    end
end
# Updates V in place with the row-wise maximum of the Q matrix and returns the
# largest absolute change (the Bellman residual in the max norm).
function _value!(V, q_vals_S_A)
    δ_max = 0.0
    for i ∈ 1:size(q_vals_S_A, 1)
        vp = maximum(@view q_vals_S_A[i, :])
        δ = abs(V[i] - vp)
        δ > δ_max && (δ_max = δ)
        V[i] = vp
    end
    return δ_max
end
function solve(solver::SparseValueIterationSolver, mdp::SparseTabularMDP)
    nS = length(states(mdp))
    nA = length(actions(mdp))
    if isempty(solver.init_util)
        v_S = zeros(nS)
    else
        @assert length(solver.init_util) == nS "Input utility dimension mismatch"
        v_S = solver.init_util
    end
    _mul_cache = similar(v_S)
    transition_A_S_S2 = transition_matrices(mdp)
    reward_S_A = reward_matrix(mdp)
    qvals_S_A = zeros(nS, nA)
    total_time = 0.0
    for i in 1:solver.max_iterations
        iter_time = @elapsed begin
            qvalue!(mdp, transition_A_S_S2, reward_S_A, v_S, qvals_S_A, _mul_cache)
            δ = _value!(v_S, qvals_S_A)
        end
        total_time += iter_time
        if solver.verbose
            @info "residual: $(δ), time: $(iter_time), total time: $(total_time)" i
        end
        δ < solver.belres && break
    end
    # Recompute the Q values with the converged value function so the returned
    # Q matrix and greedy policy are consistent with v_S.
    qvalue!(mdp, transition_A_S_S2, reward_S_A, v_S, qvals_S_A, _mul_cache)
    policy_S = argmax.(eachrow(qvals_S_A))
    return if solver.include_Q
        ValueIterationPolicy(mdp, qvals_S_A, v_S, policy_S)
    else
        ValueIterationPolicy(mdp, utility=v_S, policy=policy_S, include_Q=false)
    end
end
function solve(solver::SparseValueIterationSolver, mdp::MDP)
    # Convert the problem to its sparse tabular representation, solve it, and
    # re-wrap the result as a policy over the original MDP.
    p = solve(solver, SparseTabularMDP(mdp))
    return ValueIterationPolicy(p.qmat, p.util, p.policy, ordered_actions(mdp), p.include_Q, mdp)
end
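# End-to-end sketch (illustrative only; `mymdp` stands for any MDP that meets the
# @POMDP_require block above):
#
#   policy = solve(SparseValueIterationSolver(verbose=true), mymdp)
#   s = rand(initialstate(mymdp))   # sample a starting state
#   a = action(policy, s)           # greedy action under the computed policy
#   v = value(policy, s)            # value estimate at s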
function solve(::SparseValueIterationSolver, ::POMDP)
    throw(ArgumentError("""
          `solve(::SparseValueIterationSolver, ::POMDP)` is not supported.
          `SparseValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
          If you still wish to use the transition and reward models from your POMDP, wrap it with `UnderlyingMDP` from POMDPModelTools.jl as follows:
          ```
          solver = SparseValueIterationSolver()
          mdp = UnderlyingMDP(pomdp)
          solve(solver, mdp)
          ```
          """))
end