Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  * Include properties: action_dim, obs_dim in string repr
  • Loading branch information
dantp-ai committed Jan 26, 2024
1 parent e1f53dc commit 07eee1d
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions tianshou/utils/space_info.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Self
from typing import Any, Self

import numpy as np
from gymnasium import spaces

from tianshou.utils.string import ToStringMixin


@dataclass(kw_only=True)
class ActionSpaceInfo:
class ActionSpaceInfo(ToStringMixin):
"""A data structure for storing the different attributes of the action space."""

action_shape: int | Sequence[int]
"""The shape of the action space."""
min_action: float
"""The smallest allowable action or in the continuous case the lower bound for allowable action value."""
max_action: float
"""The largest allowable action or in the continuous case the upper bound for allowable action value."""

@property
def action_dim(self) -> int:
"""Return the number of distinct actions an agent can take it its action space."""
if isinstance(self.action_shape, int):
return self.action_shape
elif isinstance(self.action_shape, Sequence) and self.action_shape:
Expand All @@ -23,6 +31,7 @@ def action_dim(self) -> int:

@classmethod
def from_space(cls, space: spaces.Space) -> Self:
"""Return the attributes of the action space based on the instance type of the space."""
if isinstance(space, spaces.Box):
return cls(
action_shape=space.shape,
Expand All @@ -40,13 +49,20 @@ def from_space(cls, space: spaces.Space) -> Self:
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)

def _tostring_additional_entries(self) -> dict[str, Any]:
return {"action_dim": self.action_dim}


@dataclass(kw_only=True)
class ObservationSpaceInfo:
class ObservationSpaceInfo(ToStringMixin):
"""A data structure for storing the different attributes of the observation space."""

obs_shape: int | Sequence[int]
"""The shape of the observation space."""

@property
def obs_dim(self) -> int:
"""Return the number of distinct features or dimensions in the observation space."""
if isinstance(self.obs_shape, int):
return self.obs_shape
elif isinstance(self.obs_shape, Sequence) and self.obs_shape:
Expand All @@ -56,6 +72,7 @@ def obs_dim(self) -> int:

@classmethod
def from_space(cls, space: spaces.Space) -> Self:
"""Return the attributes of the observation space based on the instance type of the space."""
if isinstance(space, spaces.Box):
return cls(
obs_shape=space.shape,
Expand All @@ -69,14 +86,22 @@ def from_space(cls, space: spaces.Space) -> Self:
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)

def _tostring_additional_entries(self) -> dict[str, Any]:
return {"obs_dim": self.obs_dim}


@dataclass(kw_only=True)
class SpaceInfo:
class SpaceInfo(ToStringMixin):
"""A data structure for storing the attributes of both the action and observation space."""

action_info: ActionSpaceInfo
"""Stores the attributes of the action space."""
observation_info: ObservationSpaceInfo
"""Stores the attributes of the observation space."""

@classmethod
def from_env(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self:
"""Return the attributes of the action and observation space based on the instance type of each of the spaces."""
action_info = ActionSpaceInfo.from_space(action_space)
observation_info = ObservationSpaceInfo.from_space(observation_space)

Expand Down

0 comments on commit 07eee1d

Please sign in to comment.