Skip to content

Commit

Permalink
load edt if available instead of creating on the fly
Browse files Browse the repository at this point in the history
  • Loading branch information
hzheng40 committed Feb 28, 2024
1 parent 4a001f7 commit 1589cca
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
5 changes: 4 additions & 1 deletion gym/f110_gym/envs/laser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ def set_map(self, map: str | Track):
self.orig_c = np.cos(self.origin[2])

# get the distance transform
self.dt = get_dt(self.map_img, self.map_resolution)
if self.track.edt is not None:
self.dt = self.track.edt
else:
self.dt = get_dt(self.map_img, self.map_resolution)

return True

Expand Down
12 changes: 12 additions & 0 deletions gym/f110_gym/envs/track/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Track:
filepath: str
ext: str
occupancy_map: np.ndarray
edt: np.ndarray
centerline: Raceline
raceline: Raceline

Expand All @@ -40,6 +41,7 @@ def __init__(
filepath: str,
ext: str,
occupancy_map: np.ndarray,
edt: Optional[np.ndarray] = None,
centerline: Optional[Raceline] = None,
raceline: Optional[Raceline] = None,
):
Expand All @@ -56,6 +58,8 @@ def __init__(
file extension of the track image file
occupancy_map : np.ndarray
occupancy grid map
edt : np.ndarray
distance transform of the map
centerline : Raceline, optional
centerline of the track, by default None
raceline : Raceline, optional
Expand All @@ -65,6 +69,7 @@ def __init__(
self.filepath = filepath
self.ext = ext
self.occupancy_map = occupancy_map
self.edt = edt
self.centerline = centerline
self.raceline = raceline

Expand Down Expand Up @@ -125,6 +130,12 @@ def from_track_name(track: str):
occupancy_map[occupancy_map <= 128] = 0.0
occupancy_map[occupancy_map > 128] = 255.0

# if exists, load edt
if (track_dir / f"{track}_map.npy").exists():
edt = np.load(track_dir / f"{track}_map.npy")
else:
edt = None

# if exists, load centerline
if (track_dir / f"{track}_centerline.csv").exists():
centerline = Raceline.from_centerline_file(
Expand All @@ -146,6 +157,7 @@ def from_track_name(track: str):
filepath=str((track_dir / map_filename.stem).absolute()),
ext=map_filename.suffix,
occupancy_map=occupancy_map,
edt=edt,
centerline=centerline,
raceline=raceline,
)
Expand Down

0 comments on commit 1589cca

Please sign in to comment.