Skip to content

Commit

Permalink
rgb to hr
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueWang25 committed Oct 6, 2023
1 parent 125d19d commit f8ebbc4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/imitation/policies/obs_update_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _update_ob(
def _remove_hr_obs(
obs: Union[np.ndarray, Dict[str, np.ndarray]],
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""Removes rgb observation from the observation."""
"""Removes the human readable observation if any."""
if not isinstance(obs, dict):
return obs
if data_wrappers.HR_OBS_KEY not in obs:
Expand Down
10 changes: 5 additions & 5 deletions tests/policies/test_obs_update_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def test_remove_hr(use_hr_wrapper: bool):
{"a": np.array([1])},
),
(
"dict rgb removed successfully and got unwrapped from dict",
"dict hr removed successfully and got unwrapped from dict",
{
"a": np.array([1]),
HR_OBS_KEY: np.array([3]),
},
np.array([1]),
),
(
"dict rgb removed successfully and got dict",
"dict hr removed successfully and got dict",
{
"a": np.array([1]),
"b": np.array([2]),
Expand All @@ -68,7 +68,7 @@ def test_remove_hr(use_hr_wrapper: bool):
),
],
)
def test_remove_rgb_ob(testname, obs, expected_obs):
def test_remove_hr_ob(testname, obs, expected_obs):
got_obs = _remove_hr_obs(obs)
assert type(got_obs) is type(expected_obs)
if isinstance(got_obs, (Dict, gym.spaces.Dict)):
Expand All @@ -79,12 +79,12 @@ def test_remove_rgb_ob(testname, obs, expected_obs):
assert got_obs == expected_obs


def test_remove_rgb_obs_failure():
def test_remove_hr_obs_failure():
with pytest.raises(ValueError, match="Only human readable observation*"):
_remove_hr_obs({HR_OBS_KEY: np.array([1])})


def test_remove_rgb_obs_still_keep_origin_space_rgb():
def test_remove_hr_obs_still_keep_origin_space_rgb():
obs = {"a": np.array([1]), HR_OBS_KEY: np.array([2])}
_remove_hr_obs(obs)
assert HR_OBS_KEY in obs

0 comments on commit f8ebbc4

Please sign in to comment.