Skip to content

Commit

Permalink
Fix missing observation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Feb 28, 2024
1 parent e38fbca commit a344214
Showing 1 changed file with 84 additions and 31 deletions.
115 changes: 84 additions & 31 deletions smarts/env/utils/observation_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

import math
import warnings
from enum import IntEnum
from functools import cached_property
from functools import cached_property, partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import gymnasium as gym
Expand Down Expand Up @@ -182,13 +184,34 @@ def _format_signals(signals: List[SignalObservation]):
}


def _format_neighborhood_vehicle_states(
def _format_neighborhood_vehicle_states_option_lane_position(
neighborhood_vehicle_states: Tuple[VehicleObservation],
):
des_shp = _NEIGHBOR_SHP
rcv_shp = len(neighborhood_vehicle_states)
pad_shp = 0 if des_shp - rcv_shp < 0 else des_shp - rcv_shp

if rcv_shp == 0:
return np.zeros((des_shp, 3), dtype=np.float32)

lane_positions = [
nghb.lane_position for nghb in neighborhood_vehicle_states[:des_shp]
]
lane_positions = np.array(lane_positions, dtype=np.float32)
lane_positions = np.pad(
lane_positions, ((0, pad_shp), (0, 0)), mode="constant", constant_values=0
)
return lane_positions


def _format_neighborhood_vehicle_states(
neighborhood_vehicle_states: Tuple[VehicleObservation],
agent_interface: Optional[AgentInterface],
) -> Dict[str, Any]:
des_shp = _NEIGHBOR_SHP
rcv_shp = len(neighborhood_vehicle_states)
pad_shp = 0 if des_shp - rcv_shp < 0 else des_shp - rcv_shp

if rcv_shp == 0:
return {
"box": np.zeros((des_shp, 3), dtype=np.float32),
Expand Down Expand Up @@ -236,7 +259,7 @@ def _format_neighborhood_vehicle_states(
interest = np.pad(interest, ((0,pad_shp)), mode="constant", constant_values=False)
# fmt: on

return {
output = {
"box": box,
"heading": heading,
"id": vehicle_id,
Expand All @@ -246,6 +269,44 @@ def _format_neighborhood_vehicle_states(
"position": pos,
"speed": speed,
}
if agent_interface is not None and agent_interface.lane_positions:
output[
"lane_position"
] = _format_neighborhood_vehicle_states_option_lane_position(
neighborhood_vehicle_states
)

return output


def _configure_neighborhood_vehicle_states_space(agent_interface: AgentInterface):
sub_spaces = {
"box": gym.spaces.Box(
low=0, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32
),
"heading": gym.spaces.Box(
low=-math.pi, high=math.pi, shape=(_NEIGHBOR_SHP,), dtype=np.float32
),
"id": gym.spaces.Tuple(
(gym.spaces.Text(_ID_NAME_LIMIT, charset=_WAYPOINT_CHAR_SET),)
* _NEIGHBOR_SHP
),
"lane_index": gym.spaces.Box(
low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8
),
"position": gym.spaces.Box(
low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float64
),
"speed": gym.spaces.Box(
low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32
),
"interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP),
}
if agent_interface.lane_positions:
sub_spaces["lane_position"] = gym.spaces.Box(
low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32
)
return gym.spaces.Dict(sub_spaces)


def _format_lidar(
Expand Down Expand Up @@ -357,7 +418,9 @@ class StandardConfigurableSpaceFormat(BaseSpaceFormat):

def __init__(
self,
formatting_func: Callable[[Observation], Dict[str, Any]],
formatting_func: Callable[
[Observation, Optional[AgentInterface]], Dict[str, Any]
],
active_func: Callable[[AgentInterface], bool],
name: str,
space_func: Callable[[AgentInterface], gym.Space],
Expand All @@ -372,7 +435,7 @@ def __init__(

def format(self, obs: Observation):
"""Selects and formats the given observation to get a value that matches the :attr:`space`."""
return self._formatting_func(obs)
return self._formatting_func(obs, self._agent_interface)

def active(self, agent_interface: AgentInterface) -> bool:
"""If this formatting is active and should be included in the output."""
Expand Down Expand Up @@ -563,6 +626,13 @@ def __call__(self, agent_interface: AgentInterface) -> BaseSpaceFormat:
_VEC3_SIGNED_FLOAT32_SPACE,
)

ego_lane_position_format = StandardSpaceFormat(
lambda obs: np.array(obs.ego_vehicle_state.lane_position, dtype=np.float32),
lambda agent_interface: bool(agent_interface.lane_positions),
"lane_position",
_VEC3_SIGNED_FLOAT32_SPACE,
)

mission_space_format = StandardSpaceFormat(
lambda obs: _format_mission(obs.ego_vehicle_state.mission),
lambda _: True,
Expand Down Expand Up @@ -673,7 +743,7 @@ def name(self):


lidar_point_cloud_space_format = StandardConfigurableSpaceFormat(
lambda obs: _format_lidar(obs.lidar_point_cloud),
lambda obs, agent_interface: _format_lidar(obs.lidar_point_cloud),
lambda agent_interface: bool(agent_interface.lidar_point_cloud),
"lidar_point_cloud",
# MTA TODO: add lidar configuration
Expand All @@ -694,34 +764,11 @@ def name(self):
)


neighborhood_vehicle_states_space_format = StandardSpaceFormat(
neighborhood_vehicle_states_space_format = StandardConfigurableSpaceFormat(
lambda obs: _format_neighborhood_vehicle_states(obs.neighborhood_vehicle_states),
lambda agent_interface: bool(agent_interface.neighborhood_vehicle_states),
"neighborhood_vehicle_states",
gym.spaces.Dict(
{
"box": gym.spaces.Box(
low=0, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float32
),
"heading": gym.spaces.Box(
low=-math.pi, high=math.pi, shape=(_NEIGHBOR_SHP,), dtype=np.float32
),
"id": gym.spaces.Tuple(
(gym.spaces.Text(_ID_NAME_LIMIT, charset=_WAYPOINT_CHAR_SET),)
* _NEIGHBOR_SHP
),
"lane_index": gym.spaces.Box(
low=0, high=127, shape=(_NEIGHBOR_SHP,), dtype=np.int8
),
"position": gym.spaces.Box(
low=-1e10, high=1e10, shape=(_NEIGHBOR_SHP, 3), dtype=np.float64
),
"speed": gym.spaces.Box(
low=0, high=1e10, shape=(_NEIGHBOR_SHP,), dtype=np.float32
),
"interest": gym.spaces.MultiBinary(_NEIGHBOR_SHP),
}
),
_configure_neighborhood_vehicle_states_space,
)


Expand Down Expand Up @@ -849,6 +896,7 @@ def name(self):
# optional
ego_angular_acceleration_space_format,
ego_angular_jerk_space_format,
ego_lane_position_format,
ego_linear_acceleration_space_format,
ego_linear_jerk_space_format,
],
Expand Down Expand Up @@ -966,6 +1014,11 @@ class ObservationSpacesFormatter:
"yaw_rate":
Rotation speed around vertical axis in rad/s [0, 2pi].
dtype=np.float32.
"lane_position":
A reference line coordinate. Coordinates are s, t, and h relating
to lane offset along lane, horizontal displacement and surface
displacement.
shape=(3,). dtype=np.float64
)}
A dictionary of event markers.
Expand Down

0 comments on commit a344214

Please sign in to comment.