diff --git a/smarts/env/utils/observation_conversion.py b/smarts/env/utils/observation_conversion.py index 9de20ff3cd..1c8772e5a4 100644 --- a/smarts/env/utils/observation_conversion.py +++ b/smarts/env/utils/observation_conversion.py @@ -19,10 +19,12 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +from __future__ import annotations + import math import warnings from enum import IntEnum -from functools import cached_property +from functools import cached_property, partial from typing import Any, Callable, Dict, List, Optional, Tuple import gymnasium as gym @@ -182,13 +184,34 @@ def _format_signals(signals: List[SignalObservation]): } -def _format_neighborhood_vehicle_states( +def _format_neighborhood_vehicle_states_option_lane_position( neighborhood_vehicle_states: Tuple[VehicleObservation], ): des_shp = _NEIGHBOR_SHP rcv_shp = len(neighborhood_vehicle_states) pad_shp = 0 if des_shp - rcv_shp < 0 else des_shp - rcv_shp + if rcv_shp == 0: + return np.zeros((des_shp, 3), dtype=np.float32) + + lane_positions = [ + nghb.lane_position for nghb in neighborhood_vehicle_states[:des_shp] + ] + lane_positions = np.array(lane_positions, dtype=np.float32) + lane_positions = np.pad( + lane_positions, ((0, pad_shp), (0, 0)), mode="constant", constant_values=0 + ) + return lane_positions + + +def _format_neighborhood_vehicle_states( + neighborhood_vehicle_states: Tuple[VehicleObservation], + agent_interface: Optional[AgentInterface], +) -> Dict[str, Any]: + des_shp = _NEIGHBOR_SHP + rcv_shp = len(neighborhood_vehicle_states) + pad_shp = 0 if des_shp - rcv_shp < 0 else des_shp - rcv_shp + if rcv_shp == 0: return { "box": np.zeros((des_shp, 3), dtype=np.float32), @@ -236,7 +259,7 @@ def _format_neighborhood_vehicle_states( interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=False) # fmt: on - return { + output = { "box": box, "heading": heading, "id": vehicle_id, @@ -246,6 +269,44 @@ def _format_neighborhood_vehicle_states( "position": pos, "speed": speed, } + if agent_interface is not None and agent_interface.lane_positions: + output[ + "lane_position" + ] = _format_neighborhood_vehicle_states_option_lane_position( + neighborhood_vehicle_states + ) + + return output + + +def _configure_neighborhood_vehicle_states_space(agent_interface: AgentInterface): + sub_spaces = { + "box": gym.spaces.Box( + low=0, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32 + ), + "heading": gym.spaces.Box( + low=-math.pi, high=math.pi, shape=(_NEIGHBOR_SHP,), dtype=np.float32 + ), + "id": gym.spaces.Tuple( + (gym.spaces.Text(_ID_NAME_LIMIT, charset=_WAYPOINT_CHAR_SET),) + * _NEIGHBOR_SHP + ), + "lane_index": gym.spaces.Box( + low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8 + ), + "position": gym.spaces.Box( + low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float64 + ), + "speed": gym.spaces.Box( + low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32 + ), + "interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP), + } + if agent_interface.lane_positions: + sub_spaces["lane_position"] = gym.spaces.Box( + low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32 + ) + return gym.spaces.Dict(sub_spaces) def _format_lidar( @@ -357,7 +418,9 @@ class StandardConfigurableSpaceFormat(BaseSpaceFormat): def __init__( self, - formatting_func: Callable[[Observation], Dict[str, Any]], + formatting_func: Callable[ + [Observation, Optional[AgentInterface]], Dict[str, Any] + ], active_func: Callable[[AgentInterface], bool], name: str, space_func: Callable[[AgentInterface], gym.Space], @@ -372,7 +435,7 @@ def __init__( def format(self, obs: Observation): """Selects and formats the given observation to get a value that matches the :attr:`space`.""" - return self._formatting_func(obs) + return self._formatting_func(obs, self._agent_interface) def active(self, agent_interface: AgentInterface) -> bool: """If this formatting is active and should be included in the output.""" @@ -563,6 +626,13 @@ def __call__(self, agent_interface: AgentInterface) -> BaseSpaceFormat: _VEC3_SIGNED_FLOAT32_SPACE, ) +ego_lane_position_format = StandardSpaceFormat( + lambda obs: np.array(obs.ego_vehicle_state.lane_position, dtype=np.float32), + lambda agent_interface: bool(agent_interface.lane_positions), + "lane_position", + _VEC3_SIGNED_FLOAT32_SPACE, +) + mission_space_format = StandardSpaceFormat( lambda obs: _format_mission(obs.ego_vehicle_state.mission), lambda _: True, @@ -673,7 +743,7 @@ def name(self): lidar_point_cloud_space_format = StandardConfigurableSpaceFormat( - lambda obs: _format_lidar(obs.lidar_point_cloud), + lambda obs, agent_interface: _format_lidar(obs.lidar_point_cloud), lambda agent_interface: bool(agent_interface.lidar_point_cloud), "lidar_point_cloud", # MTA TODO: add lidar configuration @@ -694,34 +764,11 @@ def name(self): ) -neighborhood_vehicle_states_space_format = StandardSpaceFormat( +neighborhood_vehicle_states_space_format = StandardConfigurableSpaceFormat( lambda obs: _format_neighborhood_vehicle_states(obs.neighborhood_vehicle_states), lambda agent_interface: bool(agent_interface.neighborhood_vehicle_states), "neighborhood_vehicle_states", - gym.spaces.Dict( - { - "box": gym.spaces.Box( - low=0, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32 - ), - "heading": gym.spaces.Box( - low=-math.pi, high=math.pi, shape=(_NEIGHBOR_SHP,), dtype=np.float32 - ), - "id": gym.spaces.Tuple( - (gym.spaces.Text(_ID_NAME_LIMIT, charset=_WAYPOINT_CHAR_SET),) - * _NEIGHBOR_SHP - ), - "lane_index": gym.spaces.Box( - low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8 - ), - "position": gym.spaces.Box( - low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float64 - ), - "speed": gym.spaces.Box( - low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32 - ), - "interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP), - } - ), + _configure_neighborhood_vehicle_states_space, ) @@ -849,6 +896,7 @@ def name(self): # optional ego_angular_acceleration_space_format, ego_angular_jerk_space_format, + ego_lane_position_format, ego_linear_acceleration_space_format, ego_linear_jerk_space_format, ], @@ -966,6 +1014,11 @@ class ObservationSpacesFormatter: "yaw_rate": Rotation speed around vertical axis in rad/s [0, 2pi]. dtype=np.float32. + "lane_position": + A reference line coordinate. Coordinates are s, t, and h relating + to lane offset along lane, horizontal displacement and surface + displacement. + shape=(3,). dtype=np.float64 )} A dictionary of event markers.