diff --git a/src/imitation/policies/obs_update_wrapper.py b/src/imitation/policies/obs_update_wrapper.py index 93f2a57ca..4c79e694a 100644 --- a/src/imitation/policies/obs_update_wrapper.py +++ b/src/imitation/policies/obs_update_wrapper.py @@ -103,7 +103,7 @@ def _update_ob( def _remove_hr_obs( obs: Union[np.ndarray, Dict[str, np.ndarray]], ) -> Union[np.ndarray, Dict[str, np.ndarray]]: - """Removes rgb observation from the observation.""" + """Removes the human readable observation if any.""" if not isinstance(obs, dict): return obs if data_wrappers.HR_OBS_KEY not in obs: diff --git a/tests/policies/test_obs_update_wrapper.py b/tests/policies/test_obs_update_wrapper.py index cdc79ad28..2f18c7d4e 100644 --- a/tests/policies/test_obs_update_wrapper.py +++ b/tests/policies/test_obs_update_wrapper.py @@ -47,7 +47,7 @@ def test_remove_hr(use_hr_wrapper: bool): {"a": np.array([1])}, ), ( - "dict rgb removed successfully and got unwrapped from dict", + "dict hr removed successfully and got unwrapped from dict", { "a": np.array([1]), HR_OBS_KEY: np.array([3]), @@ -55,7 +55,7 @@ def test_remove_hr(use_hr_wrapper: bool): np.array([1]), ), ( - "dict rgb removed successfully and got dict", + "dict hr removed successfully and got dict", { "a": np.array([1]), "b": np.array([2]), @@ -68,7 +68,7 @@ def test_remove_hr(use_hr_wrapper: bool): ), ], ) -def test_remove_rgb_ob(testname, obs, expected_obs): +def test_remove_hr_ob(testname, obs, expected_obs): got_obs = _remove_hr_obs(obs) assert type(got_obs) is type(expected_obs) if isinstance(got_obs, (Dict, gym.spaces.Dict)): @@ -79,12 +79,12 @@ def test_remove_rgb_ob(testname, obs, expected_obs): assert got_obs == expected_obs -def test_remove_rgb_obs_failure(): +def test_remove_hr_obs_failure(): with pytest.raises(ValueError, match="Only human readable observation*"): _remove_hr_obs({HR_OBS_KEY: np.array([1])}) -def test_remove_rgb_obs_still_keep_origin_space_rgb(): +def test_remove_hr_obs_still_keep_origin_space_rgb(): obs = {"a": np.array([1]), HR_OBS_KEY: np.array([2])} _remove_hr_obs(obs) assert HR_OBS_KEY in obs