Skip to content

Commit

Permalink
[MISC] reformat examples/drone (#379)
Browse files Browse the repository at this point in the history
reformat examples/drone
  • Loading branch information
ziyanx02 authored Dec 29, 2024
1 parent ae8f556 commit 9d27d70
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
53 changes: 27 additions & 26 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower


class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)
Expand Down Expand Up @@ -52,18 +54,19 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie

# add target
if self.env_cfg["visualize_target"]:
self.target = self.scene.add_entity(morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)
self.target = self.scene.add_entity(
morph=gs.morphs.Mesh(
file="meshes/sphere.obj",
scale=0.05,
fixed=True,
collision=False,
),
surface=gs.surfaces.Rough(
diffuse_texture=gs.textures.ColorTexture(
color=(1.0, 0.5, 0.5),
),
),
)
else:
self.target = None

Expand Down Expand Up @@ -120,9 +123,7 @@ def _resample_commands(self, envs_idx):

def _at_target(self):
at_target = (
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"])
.nonzero(as_tuple=False)
.flatten()
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"]).nonzero(as_tuple=False).flatten()
)
return at_target

Expand All @@ -134,7 +135,7 @@ def step(self, actions):
# self.drone.control_dofs_position(target_dof_pos)

# 14468 is hover rpm
self.drone.set_propellels_rpm((1 + exec_actions*0.8) * 14468.429183500699)
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
self.scene.step()

# update buffers
Expand All @@ -157,12 +158,12 @@ def step(self, actions):

# check termination and reset
self.crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
self.reset_buf = (self.episode_length_buf > self.max_episode_length) | self.crash_condition

Expand Down Expand Up @@ -248,15 +249,15 @@ def _reward_smooth(self):

def _reward_yaw(self):
yaw = self.base_euler[:, 2]
yaw = torch.where(yaw > 180, yaw - 360, yaw)/180*3.14159 # use rad for yaw_reward
yaw = torch.where(yaw > 180, yaw - 360, yaw) / 180 * 3.14159 # use rad for yaw_reward
yaw_rew = torch.exp(self.reward_cfg["yaw_lambda"] * torch.abs(yaw))
return yaw_rew

def _reward_angular(self):
angular_rew = torch.norm(self.base_ang_vel/3.14159, dim=1)
angular_rew = torch.norm(self.base_ang_vel / 3.14159, dim=1)
return angular_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
crash_rew[self.crash_condition] = 1
return crash_rew
return crash_rew
3 changes: 2 additions & 1 deletion examples/drone/hover_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main():

obs, _ = env.reset()

max_sim_step = int(env_cfg["episode_length_s"]*env_cfg["max_visualize_FPS"])
max_sim_step = int(env_cfg["episode_length_s"] * env_cfg["max_visualize_FPS"])
with torch.no_grad():
if args.record:
env.cam.start_recording()
Expand All @@ -59,6 +59,7 @@ def main():
actions = policy(obs)
obs, _, rews, dones, infos = env.step(actions)


if __name__ == "__main__":
main()

Expand Down
3 changes: 1 addition & 2 deletions examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@


def get_train_cfg(exp_name, max_iterations):

train_cfg_dict = {
"algorithm": {
"clip_param": 0.2,
Expand Down Expand Up @@ -95,7 +94,7 @@ def get_cfgs():
"yaw": 0.01,
"angular": -2e-4,
"crash": -10.0,
}
},
}
command_cfg = {
"num_commands": 3,
Expand Down

0 comments on commit 9d27d70

Please sign in to comment.