From 3464c39508ed422d7990d98dc213f7c430274343 Mon Sep 17 00:00:00 2001 From: Liam Kruse Date: Mon, 25 Jul 2022 12:44:53 -0700 Subject: [PATCH] upgraded to POMDPTools --- Project.toml | 8 +++----- README.md | 2 +- src/DeepQLearning.jl | 3 +-- src/policy.jl | 6 +++--- test/flux_test.jl | 2 +- test/prototype.jl | 2 +- test/runtests.jl | 2 +- test/test_env.jl | 2 +- 8 files changed, 12 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 7be3cee..381d6b1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DeepQLearning" uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3" repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl" -version = "0.6.4" +version = "0.6.5" [deps] BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -10,8 +10,7 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" -POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" -POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" +POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -25,8 +24,7 @@ CommonRLInterface = "0.2, 0.3" EllipsisNotation = "0.4, 1.0" Flux = "0.10, 0.11, 0.12" POMDPLinter = "0.1" -POMDPModelTools = "0.3.4" -POMDPPolicies = "0.3, 0.4" +POMDPTools = "0.1" POMDPs = "0.9" Parameters = "0.12" StatsBase = "0.32, 0.33" diff --git a/README.md b/README.md index f289114..a4df914 100755 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ using POMDPs using Flux using POMDPModels using POMDPSimulators -using POMDPPolicies +using POMDPTools # load MDP model from POMDPModels or define your own! mdp = SimpleGridWorld(); diff --git a/src/DeepQLearning.jl b/src/DeepQLearning.jl index 0f6ee54..51bdeb8 100755 --- a/src/DeepQLearning.jl +++ b/src/DeepQLearning.jl @@ -6,8 +6,7 @@ using Printf using Parameters using Flux using BSON -using POMDPModelTools -using POMDPPolicies +using POMDPTools using POMDPLinter using LinearAlgebra using TensorBoardLogger: TBLogger, log_value diff --git a/src/policy.jl b/src/policy.jl index c74bb40..42d32ba 100755 --- a/src/policy.jl +++ b/src/policy.jl @@ -67,9 +67,9 @@ POMDPs.action(policy::NNPolicy, o) = _action(policy, o) POMDPs.action(policy::NNPolicy{P}, s) where {P <: MDP} = _action(policy, POMDPs.convert_s(Array{Float32}, s, policy.problem)) POMDPs.action(policy::NNPolicy{P}, o) where {P <: POMDP} = _action(policy, POMDPs.convert_o(Array{Float32}, o, policy.problem)) -POMDPPolicies.actionvalues(policy::NNPolicy, o) = _actionvalues(policy, o) -POMDPPolicies.actionvalues(policy::NNPolicy{P}, s) where {P<:MDP} = _actionvalues(policy, POMDPs.convert_s(Array{Float32}, s, policy.problem)) -POMDPPolicies.actionvalues(policy::NNPolicy{P}, o) where {P<:POMDP} = _actionvalues(policy, POMDPs.convert_o(Array{Float32}, o, policy.problem)) +POMDPTools.actionvalues(policy::NNPolicy, o) = _actionvalues(policy, o) +POMDPTools.actionvalues(policy::NNPolicy{P}, s) where {P<:MDP} = _actionvalues(policy, POMDPs.convert_s(Array{Float32}, s, policy.problem)) +POMDPTools.actionvalues(policy::NNPolicy{P}, o) where {P<:POMDP} = _actionvalues(policy, POMDPs.convert_o(Array{Float32}, o, policy.problem)) POMDPs.value(policy::NNPolicy, o) = _value(policy, o) POMDPs.value(policy::NNPolicy{P}, s) where {P <: MDP} = _value(policy, POMDPs.convert_s(Array{Float32}, s, policy.problem)) diff --git a/test/flux_test.jl b/test/flux_test.jl index dc3c289..9af3d94 100644 --- a/test/flux_test.jl +++ b/test/flux_test.jl @@ -4,8 +4,8 @@ using Random using DeepQLearning using POMDPModels using POMDPSimulators +using POMDPTools using RLInterface -using POMDPPolicies using Test using Flux diff --git a/test/prototype.jl b/test/prototype.jl index a26f175..89dbb5b 100644 --- a/test/prototype.jl +++ b/test/prototype.jl @@ -2,7 +2,7 @@ using Revise using Random using BenchmarkTools using POMDPs -using POMDPModelTools +using POMDPTools # using CuArrays using Flux using DeepQLearning diff --git a/test/runtests.jl b/test/runtests.jl index b6b079a..293dd89 100755 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using DeepQLearning using POMDPModels using POMDPSimulators -using POMDPPolicies +using POMDPTools using Flux using Random using StaticArrays diff --git a/test/test_env.jl b/test/test_env.jl index b5053e9..22580e0 100755 --- a/test/test_env.jl +++ b/test/test_env.jl @@ -1,5 +1,5 @@ using POMDPs -using POMDPModelTools +using POMDPTools # Define a test environment # it has 2 states, it ends up after taking 5 action