Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
YilingQiao committed Dec 27, 2024
1 parent 7707fbe commit affa7d9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
20 changes: 11 additions & 9 deletions examples/drone/hover_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import genesis as gs
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat


def gs_rand_float(lower, upper, shape, device):
return (upper - lower) * torch.rand(size=shape, device=device) + lower


class HoverEnv:
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
self.device = torch.device(device)
Expand Down Expand Up @@ -213,17 +215,17 @@ def _reward_target(self):
def _reward_smooth(self):
smooth_rew = torch.sum(torch.square(self.actions - self.last_actions), dim=1)
return smooth_rew

def _reward_crash(self):
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)

crash_condition = (
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
)
crash_rew[crash_condition] = -1
return crash_rew
return crash_rew
2 changes: 1 addition & 1 deletion examples/drone/hover_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_cfgs():
},
}
reward_cfg = {
"reward_scales":{
"reward_scales": {
"target": 5.0,
"smooth": -0.001,
"crash": 1.0,
Expand Down
1 change: 1 addition & 0 deletions genesis/ext/pyrender/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from tkinter import Tk
from tkinter import filedialog

try:
root = Tk()
root.withdraw()
Expand Down

0 comments on commit affa7d9

Please sign in to comment.