diff --git a/examples/create_edt.py b/examples/create_edt.py new file mode 100644 index 00000000..c1959237 --- /dev/null +++ b/examples/create_edt.py @@ -0,0 +1,41 @@ +from f110_gym.envs.track import Track +from scipy.ndimage import distance_transform_edt as edt +import numpy as np + +DEFAULT_MAP_NAMES = [ + "Austin", + "BrandsHatch", + "Budapest", + "Catalunya", + "Hockenheim", + "IMS", + "Melbourne", + "MexicoCity", + "Montreal", + "Monza", + "MoscowRaceway", + "Nuerburgring", + "Oschersleben", + "Sakhir", + "SaoPaulo", + "Sepang", + "Shanghai", + "Silverstone", + "Sochi", + "Spa", + "Spielberg", + "YasMarina", + "Zandvoort", +] + +for track_name in DEFAULT_MAP_NAMES: + track = Track.from_track_name(track_name) + occupancy_map = track.occupancy_map + resolution = track.spec.resolution + + dt = resolution * edt(occupancy_map) + + # saving + np.save(track.filepath, dt) + + track_wedt = Track.from_track_name(track_name) diff --git a/f1tenth_gym/envs/laser_models.py b/f1tenth_gym/envs/laser_models.py index 8d3c2e9b..69d869a1 100644 --- a/f1tenth_gym/envs/laser_models.py +++ b/f1tenth_gym/envs/laser_models.py @@ -488,7 +488,10 @@ def set_map(self, map: str | Track): self.orig_c = np.cos(self.origin[2]) # get the distance transform - self.dt = get_dt(self.map_img, self.map_resolution) + if self.track.edt is not None: + self.dt = self.track.edt + else: + self.dt = get_dt(self.map_img, self.map_resolution) return True diff --git a/f1tenth_gym/envs/track/track.py b/f1tenth_gym/envs/track/track.py index d0bd6585..e18ea9eb 100644 --- a/f1tenth_gym/envs/track/track.py +++ b/f1tenth_gym/envs/track/track.py @@ -2,6 +2,7 @@ import pathlib from dataclasses import dataclass from typing import Tuple, Optional +import os import numpy as np import yaml @@ -31,6 +32,7 @@ class Track: filepath: str ext: str occupancy_map: np.ndarray + edt: np.ndarray centerline: Raceline raceline: Raceline @@ -40,6 +42,7 @@ def __init__( filepath: str, ext: str, occupancy_map: np.ndarray, + edt: Optional[np.ndarray] = None, centerline: Optional[Raceline] = None, raceline: Optional[Raceline] = None, ): @@ -56,6 +59,8 @@ def __init__( file extension of the track image file occupancy_map : np.ndarray occupancy grid map + edt : np.ndarray + distance transform of the map centerline : Raceline, optional centerline of the track, by default None raceline : Raceline, optional @@ -65,6 +70,7 @@ def __init__( self.filepath = filepath self.ext = ext self.occupancy_map = occupancy_map + self.edt = edt self.centerline = centerline self.raceline = raceline @@ -125,6 +131,18 @@ def from_track_name(track: str): occupancy_map[occupancy_map <= 128] = 0.0 occupancy_map[occupancy_map > 128] = 255.0 + # if exists and it has been created for the current map image, load edt + map_filepath = (track_dir / map_filename).absolute() + track_filepath = map_filepath.with_suffix("") + edt_filepath = track_dir / f"{track}_map.npy" + if edt_filepath.exists() and os.path.getmtime(edt_filepath) >= os.path.getmtime(map_filepath): + edt = np.load(track_dir / f"{track}_map.npy") + else: + from scipy.ndimage import distance_transform_edt as edt + resolution = track_spec.resolution + edt = resolution * edt(occupancy_map) + np.save(track_filepath, edt) + # if exists, load centerline if (track_dir / f"{track}_centerline.csv").exists(): centerline = Raceline.from_centerline_file( @@ -146,6 +164,7 @@ def from_track_name(track: str): filepath=str((track_dir / map_filename.stem).absolute()), ext=map_filename.suffix, occupancy_map=occupancy_map, + edt=edt, centerline=centerline, raceline=raceline, ) diff --git a/f1tenth_gym/envs/track/utils.py b/f1tenth_gym/envs/track/utils.py index e3e88a09..942a1499 100644 --- a/f1tenth_gym/envs/track/utils.py +++ b/f1tenth_gym/envs/track/utils.py @@ -24,7 +24,7 @@ def find_track_dir(track_name: str) -> pathlib.Path: FileNotFoundError if no map directory matching the track name is found """ - map_dir = pathlib.Path(__file__).parent.parent.parent.parent / "maps" + map_dir = pathlib.Path(__file__).parent.parent.parent.parent.parent / "maps" if not (map_dir / track_name).exists(): print("Downloading Files for: " + track_name) diff --git a/tests/test_track.py b/tests/test_track.py index 98ff7bbb..d35b206e 100644 --- a/tests/test_track.py +++ b/tests/test_track.py @@ -1,3 +1,5 @@ +import datetime +import os import pathlib import time import unittest @@ -138,3 +140,29 @@ def test_download_racetrack(self): # rename the backup track dir to its original name track_backup_dir = find_track_dir(tmp_dir.stem) track_backup_dir.rename(track_dir) + + def test_edt_update(self): + """ + Test the re-creation of the edt if the map modification time is more recent. + """ + track = Track.from_track_name("Spielberg") + + # set the map image modification/access time to now + now = datetime.datetime.now() + dt_epoch = now.timestamp() + map_filepath = pathlib.Path(track.filepath).parent / track.spec.image + os.utime(map_filepath, (dt_epoch, dt_epoch)) + + # check the edt modification time is now < the map image time + edt_filepath = map_filepath.with_suffix(".npy") + self.assertTrue(os.path.getmtime(map_filepath) > os.path.getmtime(edt_filepath), + f"expected the map image modification time to be > the edt modification time") + + # this should force the edt to be recomputed + # check the edt modification time is not > the map image time + track2 = Track.from_track_name("Spielberg") + self.assertTrue(os.path.getmtime(map_filepath) < os.path.getmtime(edt_filepath), + f"expected the map image modification time to be > the edt modification time") + + # check consistency in the maps edts + self.assertTrue(np.allclose(track.edt, track2.edt), f"expected the same edt transform for {track.spec.name}")