Skip to content

Commit

Permalink
Merge branch 'dev-map-reset' into dev-dynamics-ros
Browse files Browse the repository at this point in the history
  • Loading branch information
AhmadAmine998 committed Oct 8, 2024
2 parents 967e1aa + 1ec0546 commit c23ad6d
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 0 deletions.
7 changes: 7 additions & 0 deletions f1tenth_gym/envs/reset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from .masked_reset import GridResetFn, AllTrackResetFn
from .map_reset import AllMapResetFn
from .reset_fn import ResetFn
from ..track import Track

Expand All @@ -10,6 +11,12 @@ def make_reset_fn(type: str | None, track: Track, num_agents: int, **kwargs) ->
try:
refline_token, reset_token, shuffle_token = type.split("_")

if refline_token == "map":
reset_fn = {"random": AllMapResetFn}[reset_token]
shuffle = {"static": False, "random": True}[shuffle_token]
return reset_fn(track=track, num_agents=num_agents, shuffle=shuffle, **kwargs)

# "cl" or "rl"
refline = {"cl": track.centerline, "rl": track.raceline}[refline_token]
reset_fn = {"grid": GridResetFn, "random": AllTrackResetFn}[reset_token]
shuffle = {"static": False, "random": True}[shuffle_token]
Expand Down
85 changes: 85 additions & 0 deletions f1tenth_gym/envs/reset/map_reset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from abc import abstractmethod

import cv2
import numpy as np

from .reset_fn import ResetFn
from .utils import sample_around_pose
from ..track import Track


class MapResetFn(ResetFn):
@abstractmethod
def get_mask(self) -> np.ndarray:
pass

def __init__(
self,
track: Track,
num_agents: int,
move_laterally: bool,
min_dist: float,
max_dist: float,
):
self.track = track
self.n_agents = num_agents
self.min_dist = min_dist
self.max_dist = max_dist
self.move_laterally = move_laterally
# Mask is a 2D array of booleans of where the agents can be placed
# Should acount for max_dist from obstacles
self.mask = self.get_mask()


def sample(self) -> np.ndarray:
# Random ample an x-y position from the mask
valid_x, valid_y = np.where(self.mask)
idx = np.random.choice(len(valid_x))
pose_x = valid_x[idx] * self.track.spec.resolution + self.track.spec.origin[0]
pose_y = valid_y[idx] * self.track.spec.resolution + self.track.spec.origin[1]
pose_theta = np.random.uniform(-np.pi, np.pi)
pose = np.array([pose_x, pose_y, pose_theta])

poses = sample_around_pose(
pose=pose,
n_agents=self.n_agents,
min_dist=self.min_dist,
max_dist=self.max_dist,
)
return poses

class AllMapResetFn(MapResetFn):
def __init__(
self,
track: Track,
num_agents: int,
move_laterally: bool = True,
shuffle: bool = True,
min_dist: float = 0.5,
max_dist: float = 1.0,
):
super().__init__(
track=track,
num_agents=num_agents,
move_laterally=move_laterally,
min_dist=min_dist,
max_dist=max_dist,
)
self.shuffle = shuffle

def get_mask(self) -> np.ndarray:
# Create mask from occupancy grid enlarged by max_dist
dilation_size = int(self.max_dist / self.track.spec.resolution)
kernel = np.ones((dilation_size, dilation_size), np.uint8)
inverted_occ_map = (255 - self.track.occupancy_map)
dilated = cv2.dilate(inverted_occ_map, kernel, iterations=1)
dilated_inverted = (255 - dilated)
return dilated_inverted == 255

def sample(self) -> np.ndarray:
poses = super().sample()

if self.shuffle:
np.random.shuffle(poses)

return poses
33 changes: 33 additions & 0 deletions f1tenth_gym/envs/reset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,36 @@ def sample_around_waypoint(
)

return np.array(poses)

def sample_around_pose(
pose: np.ndarray,
n_agents: int,
min_dist: float,
max_dist: float,
) -> np.ndarray:
"""
Compute n poses around a given pose.
It iteratively samples the next agent within a distance range from the previous one.
Note: no guarantee that the agents are on the track nor that they are not colliding with the environment.
Args:
- pose: the initial pose
- n_agents: the number of agents
- min_dist: the minimum distance between two consecutive agents
- max_dist: the maximum distance between two consecutive agents
"""
current_pose = pose

poses = []
for i in range(n_agents):
x, y, theta = current_pose
pose = np.array([x, y, theta])
poses.append(pose)
# sample next pose
dist = np.random.uniform(min_dist, max_dist)
theta = np.random.uniform(-np.pi, np.pi)
x += dist * np.cos(theta)
y += dist * np.sin(theta)
current_pose = np.array([x, y, theta])

return np.array(poses)
60 changes: 60 additions & 0 deletions f1tenth_gym/envs/track/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,66 @@ def from_track_name(track: str, track_scale: float = 1.0) -> Track:
print(ex)
raise FileNotFoundError(f"It could not load track {track}") from ex

@staticmethod
def from_track_path(path: pathlib.Path):
"""
Load track from track path.
Parameters
----------
path : pathlib.Path
path to the track yaml file
Returns
-------
Track
track object
Raises
------
FileNotFoundError
if the track cannot be loaded
"""
try:
if type(path) is str:
path = pathlib.Path(path)

track_spec = Track.load_spec(
track=path.stem, filespec=path
)

# load occupancy grid
# Image path is from path + image name from track_spec
image_path = path.parent / track_spec.image
image = Image.open(image_path).transpose(Transpose.FLIP_TOP_BOTTOM)
occupancy_map = np.array(image).astype(np.float32)
occupancy_map[occupancy_map <= 128] = 0.0
occupancy_map[occupancy_map > 128] = 255.0

# if exists, load centerline
if (path / f"{path.stem}_centerline.csv").exists():
centerline = Raceline.from_centerline_file(path / f"{path.stem}_centerline.csv")
else:
centerline = None

# if exists, load raceline
if (path / f"{path.stem}_raceline.csv").exists():
raceline = Raceline.from_raceline_file(path / f"{path.stem}_raceline.csv")
else:
raceline = centerline

return Track(
spec=track_spec,
filepath=str(path.absolute()),
ext=image_path.suffix,
occupancy_map=occupancy_map,
centerline=centerline,
raceline=raceline,
)
except Exception as ex:
print(ex)
raise FileNotFoundError(f"It could not load track {path}") from ex

@staticmethod
def from_refline(x: np.ndarray, y: np.ndarray, velx: np.ndarray):
"""
Expand Down

0 comments on commit c23ad6d

Please sign in to comment.