-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
93 add test for resetting the poses in vectorized env #94
Changes from all commits
1df8d2c
9b536b2
d3ca68a
e08e66b
fcbbf83
21d8193
a732f4a
47e64f8
22b4c9b
465dbf2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from f110_gym.envs.reset.masked_reset import GridResetFn, AllTrackResetFn | ||
from f110_gym.envs.reset.reset_fn import ResetFn | ||
from f110_gym.envs.track import Track | ||
|
||
|
||
def make_reset_fn(type: str, track: Track, num_agents: int, **kwargs) -> ResetFn: | ||
if type == "grid_static": | ||
return GridResetFn(track=track, num_agents=num_agents, shuffle=False, **kwargs) | ||
elif type == "grid_random": | ||
return GridResetFn(track=track, num_agents=num_agents, shuffle=True, **kwargs) | ||
elif type == "random_static": | ||
return AllTrackResetFn( | ||
track=track, num_agents=num_agents, shuffle=False, **kwargs | ||
) | ||
elif type == "random_random": | ||
return AllTrackResetFn( | ||
track=track, num_agents=num_agents, shuffle=True, **kwargs | ||
) | ||
else: | ||
raise ValueError(f"invalid reset type {type}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from abc import abstractmethod | ||
|
||
import numpy as np | ||
|
||
from f110_gym.envs.reset.reset_fn import ResetFn | ||
from f110_gym.envs.reset.utils import sample_around_waypoint | ||
from f110_gym.envs.track import Track | ||
|
||
|
||
class MaskedResetFn(ResetFn): | ||
@abstractmethod | ||
def get_mask(self) -> np.ndarray: | ||
pass | ||
|
||
def __init__( | ||
self, | ||
track: Track, | ||
num_agents: int, | ||
move_laterally: bool, | ||
min_dist: float, | ||
max_dist: float, | ||
): | ||
self.track = track | ||
self.n_agents = num_agents | ||
self.min_dist = min_dist | ||
self.max_dist = max_dist | ||
self.move_laterally = move_laterally | ||
self.mask = self.get_mask() | ||
|
||
def sample(self) -> np.ndarray: | ||
waypoint_id = np.random.choice(np.where(self.mask)[0]) | ||
poses = sample_around_waypoint( | ||
track=self.track, | ||
waypoint_id=waypoint_id, | ||
n_agents=self.n_agents, | ||
min_dist=self.min_dist, | ||
max_dist=self.max_dist, | ||
move_laterally=self.move_laterally, | ||
) | ||
return poses | ||
|
||
|
||
class GridResetFn(MaskedResetFn): | ||
def __init__( | ||
self, | ||
track: Track, | ||
num_agents: int, | ||
move_laterally: bool = True, | ||
shuffle: bool = True, | ||
start_width: float = 1.0, | ||
min_dist: float = 1.5, | ||
max_dist: float = 2.5, | ||
): | ||
self.start_width = start_width | ||
self.shuffle = shuffle | ||
|
||
super().__init__( | ||
track=track, | ||
num_agents=num_agents, | ||
move_laterally=move_laterally, | ||
min_dist=min_dist, | ||
max_dist=max_dist, | ||
) | ||
|
||
def get_mask(self) -> np.ndarray: | ||
# approximate the nr waypoints in the starting line | ||
step_size = self.track.centerline.length / self.track.centerline.n | ||
n_wps = int(self.start_width / step_size) | ||
|
||
mask = np.zeros(self.track.centerline.n) | ||
mask[:n_wps] = 1 | ||
return mask.astype(bool) | ||
|
||
def sample(self) -> np.ndarray: | ||
poses = super().sample() | ||
|
||
if self.shuffle: | ||
np.random.shuffle(poses) | ||
|
||
return poses | ||
|
||
|
||
class AllTrackResetFn(MaskedResetFn): | ||
def __init__( | ||
self, | ||
track: Track, | ||
num_agents: int, | ||
move_laterally: bool = True, | ||
shuffle: bool = True, | ||
min_dist: float = 1.5, | ||
max_dist: float = 2.5, | ||
): | ||
super().__init__( | ||
track=track, | ||
num_agents=num_agents, | ||
move_laterally=move_laterally, | ||
min_dist=min_dist, | ||
max_dist=max_dist, | ||
) | ||
self.shuffle = shuffle | ||
|
||
def get_mask(self) -> np.ndarray: | ||
return np.ones(self.track.centerline.n).astype(bool) | ||
|
||
def sample(self) -> np.ndarray: | ||
poses = super().sample() | ||
|
||
if self.shuffle: | ||
np.random.shuffle(poses) | ||
|
||
return poses |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from abc import abstractmethod | ||
|
||
import numpy as np | ||
|
||
|
||
class ResetFn: | ||
@abstractmethod | ||
def sample(self) -> np.ndarray: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from f110_gym.envs.track import Track | ||
|
||
|
||
def sample_around_waypoint( | ||
track: Track, | ||
waypoint_id: int, | ||
n_agents: int, | ||
min_dist: float, | ||
max_dist: float, | ||
move_laterally: bool = True, | ||
) -> np.ndarray: | ||
""" | ||
Compute n poses around a given waypoint in the track. | ||
It iteratively samples the next agent within a distance range from the previous one. | ||
|
||
Args: | ||
- track: the track object | ||
- waypoint_id: the id of the first waypoint from which start the sampling | ||
- n_agents: the number of agents | ||
- min_dist: the minimum distance between two consecutive agents | ||
- max_dist: the maximum distance between two consecutive agents | ||
- move_laterally: if True, the agents are sampled on the left/right of the track centerline | ||
""" | ||
current_wp_id = waypoint_id | ||
n_waypoints = track.centerline.n | ||
|
||
poses = [] | ||
rnd_sign = ( | ||
np.random.choice([-1, 1]) if move_laterally else 1 | ||
) # random sign to sample lateral position (left/right) | ||
for i in range(n_agents): | ||
# compute pose from current wp_id | ||
wp = [ | ||
track.centerline.xs[current_wp_id], | ||
track.centerline.ys[current_wp_id], | ||
] | ||
next_wp_id = (current_wp_id + 1) % n_waypoints | ||
next_wp = [ | ||
track.centerline.xs[next_wp_id], | ||
track.centerline.ys[next_wp_id], | ||
] | ||
theta = np.arctan2(next_wp[1] - wp[1], next_wp[0] - wp[0]) | ||
|
||
x, y = wp[0], wp[1] | ||
if n_agents > 1: | ||
lat_offset = rnd_sign * (-1.0) ** i * (1.0 / n_agents) | ||
x += lat_offset * np.cos(theta + np.pi / 2) | ||
y += lat_offset * np.sin(theta + np.pi / 2) | ||
|
||
pose = np.array([x, y, theta]) | ||
poses.append(pose) | ||
# find id of next waypoint which has mind <= dist <= maxd | ||
first_id, interval_len = ( | ||
None, | ||
None, | ||
) # first wp id with dist > mind, len of the interval btw first/last wp | ||
pnt_id = current_wp_id # moving pointer to scan the next waypoints | ||
dist = 0.0 | ||
while dist <= max_dist: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this whole while loop is only needed if you don't have the arclengths in the track. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's not exactly the same but even if there's a speedup it won't be that different |
||
# sanity check | ||
if pnt_id > n_waypoints - 1: | ||
pnt_id = 0 | ||
# increment distance | ||
x_diff = track.centerline.xs[pnt_id] - track.centerline.xs[pnt_id - 1] | ||
y_diff = track.centerline.ys[pnt_id] - track.centerline.ys[pnt_id - 1] | ||
dist = dist + np.linalg.norm( | ||
[y_diff, x_diff] | ||
) # approx distance by summing linear segments | ||
# look for sampling interval | ||
if first_id is None and dist >= min_dist: # not found first id yet | ||
first_id = pnt_id | ||
interval_len = 0 | ||
if ( | ||
first_id is not None and dist <= max_dist | ||
): # found first id, increment interval length | ||
interval_len += 1 | ||
pnt_id += 1 | ||
# sample next waypoint | ||
current_wp_id = (first_id + np.random.randint(0, interval_len + 1)) % ( | ||
n_waypoints | ||
) | ||
|
||
return np.array(poses) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm not sure if it's i-th power or i(1/n)-th power?