Skip to content

Commit

Permalink
Assume action/obs_space is non-empty
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Jan 29, 2024
1 parent 5df51ec commit 667b4be
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions tianshou/utils/space_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ class ActionSpaceInfo(ToStringMixin):

@property
def action_dim(self) -> int:
"""Return the number of distinct actions an agent can take it its action space."""
"""Return the number of distinct actions (must be non-empty) an agent can take it its action space."""
if isinstance(self.action_shape, int):
return self.action_shape
elif isinstance(self.action_shape, Sequence) and self.action_shape:
return int(np.prod(self.action_shape))
else:
raise ValueError("Invalid action_shape: {self.action_shape}.")
return int(np.prod(self.action_shape))

@classmethod
def from_space(cls, space: spaces.Space) -> Self:
Expand Down Expand Up @@ -62,13 +60,11 @@ class ObservationSpaceInfo(ToStringMixin):

@property
def obs_dim(self) -> int:
"""Return the number of distinct features or dimensions in the observation space."""
"""Return the number of distinct features (must be non-empty) or dimensions in the observation space."""
if isinstance(self.obs_shape, int):
return self.obs_shape
elif isinstance(self.obs_shape, Sequence) and self.obs_shape:
return int(np.prod(self.obs_shape))
else:
raise ValueError("Invalid obs_shape: {self.obs_shape}.")
return int(np.prod(self.obs_shape))

@classmethod
def from_space(cls, space: spaces.Space) -> Self:
Expand Down

0 comments on commit 667b4be

Please sign in to comment.