Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backing off for now #31

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ NamedTupleTools = "0.11, 0.12, 0.13"
POMDPModelTools = "0.3.1"
POMDPTesting = "0.2.1"
POMDPs = "0.9"
Tricks = "0.1"
julia = "1"

[extras]
Expand Down
1 change: 0 additions & 1 deletion src/QuickPOMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using POMDPTesting
using UUIDs
using NamedTupleTools
using Random
using Tricks: static_hasmethod

export
DiscreteExplicitPOMDP,
Expand Down
79 changes: 73 additions & 6 deletions src/quick.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function QuickMDP(id=uuid4(); kwargs...)

S = infer_statetype(kwd)
A = infer_actiontype(kwd)

kwd =

d = namedtuple(keys(kwd)...)(values(kwd)...)
qm = QuickMDP{id, S, A, typeof(d)}(d)
return qm
Expand Down Expand Up @@ -223,9 +226,9 @@ end
function POMDPs.observation(m::QuickPOMDP, args...)
if haskey(m.data, :observation)
obs = m.data[:observation]
if static_hasmethod(obs, typeof(args))
if hasmethod(obs, typeof(args))
return obs(args...)
elseif length(args) == 3 && static_hasmethod(obs, typeof(args[2:3]))
elseif length(args) == 3 && hasmethod(obs, typeof(args[2:3]))
return obs(args[2:3]...)
else
return obs(args...)
Expand All @@ -239,15 +242,15 @@ end
function POMDPs.reward(m::QuickModel, args...)
if haskey(m.data, :reward)
r = m.data[:reward]
if static_hasmethod(r, typeof(args)) # static_hasmethod could cause issues, but I think it is worth doing in this single spot
if hasmethod(r, typeof(args)) # static_hasmethod could cause issues, but I think it is worth doing in this single spot
return r(args...)
elseif m isa POMDP && length(args) == 4
if static_hasmethod(r, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp)
if hasmethod(r, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp)
return r(args[1:3]...)
elseif static_hasmethod(r, typeof(args[1:2])) # (s, a, sp, o) -> (s, a)
elseif hasmethod(r, typeof(args[1:2])) # (s, a, sp, o) -> (s, a)
return r(args[1:2]...)
end
elseif length(args) == 3 && static_hasmethod(r, typeof(args[1:2])) # (s, a, sp) -> (s, a)
elseif length(args) == 3 && hasmethod(r, typeof(args[1:2])) # (s, a, sp) -> (s, a)
return r(args[1:2]...)
else
return r(args...)
Expand All @@ -257,6 +260,70 @@ function POMDPs.reward(m::QuickModel, args...)
end
end

struct QuickRewardModel{ArgNums, F} <: Function
f::F
hasmethod_fallback::Bool
end

QuickRewardModel(f::Function, S, A; hasmethod_fallback::Bool=true) = QuickRewardModel{reward_argnums(f, S, A), typeof(f)}(f, hasmethod_fallback)
QuickRewardModel(f::Function, S, A, O; hasmethod_fallback::Bool=true) = QuickRewardModel{reward_argnums(f, S, A, O), typeof(f)}(f, hasmethod_fallback)
QuickRewardModel(r::QuickRewardModel, args...) = r

function reward_argnums(f, S, A)
ans = []
if hasmethod(f, Tuple{S,A})
push!(ans, 2)
end
if hasmethod(f, Tuple{S,A,S})
push!(ans, 3)
end
return (ans...,) # convert to tuple
end

function reward_argnums(f, S, A, O)
if hasmethod(f, Tuple{S, A, S, O})
return (reward_argnums(f, S, A)..., 4)
else
return reward_argnums(f, S, A)
end
end

function (r::QuickRewardModel{ArgNums})(args...) where ArgNums
if length(args) in ArgNums
return r.f(args...)
elseif maximum(ArgNums) < length(args)
return r.f(args[1:maximum(ArgNums)]...)
elseif r.f.hasmethod_fallback
if hasmethod(r.f, typeof(args))
found = r.f(args...)
elseif m isa POMDP && length(args) == 4
if hasmethod(r.f, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp)
found = r.f(args[1:3]...)
elseif hasmethod(r.f, typeof(args[1:2])) # (s, a, sp, o) -> (s, a)
found = r.f(args[1:2]...)
end
elseif length(args) == 3 && hasmethod(r.f, typeof(args[1:2])) # (s, a, sp) -> (s, a)
found = r.f(args[1:2]...)
else
return r.f(args...)
end
@warn("""A Quick(PO)MDP had to use hasmethod as a fallback to find the correct method of
the reward function to use.

This may be caused by adding new methods to the reward function after creating
the Quick(PO)MDP and can cause significant perfromance degredation. Originally,
the Quick(PO)MDP found reward methods with the following numbers of arguments:

$(ArgNums)

Recommend adding all methods to the reward function before creaing the
Quick(PO)MDP.""", current_methods=methods(r.f))
return found
else
return r.f(args...)
end
end

@forward_to_data POMDPs.initialstate
@forward_to_data POMDPs.initialobs

Expand Down