Skip to content

Commit

Permalink
Merge branch 'Genesis-Embodied-AI:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
00make authored Dec 22, 2024
2 parents 57c6d3c + 3201b23 commit 5eab75a
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 28 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ pip install genesis-world # Requires Python >=3.9;

You also need to install **PyTorch** following the [official instructions](https://pytorch.org/get-started/locally/).

If you would like to try out the latest version, we suggest you to git clone from the repo and do `pip install -e .` instead of via PyPI.

### Documentation

Please refer to our [documentation site](https://genesis-world.readthedocs.io/en/latest/user_guide/index.html) for detailed installation steps, tutorials and API references.
Please refer to our [documentation site (English)](https://genesis-world.readthedocs.io/en/latest/user_guide/index.html) / [(Chinese)](https://genesis-world.readthedocs.io/zh-cn/latest/user_guide/index.html) for detailed installation steps, tutorials and API references.

## Contributing to Genesis

Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pip install genesis-world # 需要 Python >=3.9;

### 文档

请参阅我们的 [文档网站](https://genesis-world.readthedocs.io/en/latest/user_guide/index.html) 以获取详细的安装步骤、教程和 API 参考。
请参阅我们的 [文档网站(英文)](https://genesis-world.readthedocs.io/en/latest/user_guide/index.html)/[(中文)](https://genesis-world.readthedocs.io/zh-cn/latest/user_guide/index.html)以获取详细的安装步骤、教程和 API 参考。

## 贡献 Genesis

Expand Down
2 changes: 1 addition & 1 deletion examples/coupling/cloth_on_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def main():
args = parser.parse_args()

########################## init ##########################
gs.init(seed=0, precision="32", logging_level="debug")
gs.init(seed=0, precision="32", logging_level="debug", backend=gs.cpu if args.cpu else gs.gpu)

scene = gs.Scene(
sim_options=gs.options.SimOptions(
Expand Down
78 changes: 78 additions & 0 deletions examples/rigid/apply_external_force_torque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import argparse
import numpy as np
import genesis as gs

from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--vis", action="store_true", default=False)
args = parser.parse_args()

########################## init ##########################
gs.init(backend=gs.gpu)

########################## create a scene ##########################
viewer_options = gs.options.ViewerOptions(
camera_pos=(0, -3.5, 2.5),
camera_lookat=(0.0, 0.0, 1.0),
camera_fov=40,
max_FPS=60,
)

scene = gs.Scene(
viewer_options=viewer_options,
sim_options=gs.options.SimOptions(
dt=0.01,
),
show_viewer=args.vis,
)

########################## entities ##########################
plane = scene.add_entity(
gs.morphs.Plane(),
)
cube = scene.add_entity(
gs.morphs.Box(
pos=(0, 0, 1.0),
size=(0.2, 0.2, 0.2),
),
)
########################## build ##########################
scene.build()

for solver in scene.sim.solvers:
if not isinstance(solver, RigidSolver):
continue
rigid_solver = solver

link_idx = [
1,
]
rotation_direction = 1
for i in range(1000):
cube_pos = rigid_solver.get_links_pos(link_idx)
cube_pos[:, 2] -= 1
force = -100 * cube_pos
rigid_solver.apply_links_external_force(
force=force,
links_idx=link_idx,
)

torque = [
[0, 0, rotation_direction * 5],
]
rigid_solver.apply_links_external_torque(
torque=torque,
links_idx=link_idx,
)

scene.step()

if (i + 50) % 100 == 0:
rotation_direction *= -1


if __name__ == "__main__":
main()
101 changes: 101 additions & 0 deletions examples/rigid/domain_randomization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import argparse

import numpy as np
import torch

import genesis as gs


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--vis", action="store_true", default=False)
args = parser.parse_args()

########################## init ##########################
gs.init(seed=0, precision="32", logging_level="debug")

########################## create a scene ##########################
scene = gs.Scene(
viewer_options=gs.options.ViewerOptions(
camera_pos=(0.0, -2, 1.5),
camera_lookat=(0.0, 0.0, 0.5),
camera_fov=40,
max_FPS=200,
),
show_viewer=args.vis,
rigid_options=gs.options.RigidOptions(
dt=0.01,
constraint_solver=gs.constraint_solver.Newton,
),
)

########################## entities ##########################
scene.add_entity(
gs.morphs.Plane(),
)
robot = scene.add_entity(
gs.morphs.URDF(
file="urdf/go2/urdf/go2.urdf",
pos=(0, 0, 0.4),
),
)
########################## build ##########################
n_envs = 8
scene.build(n_envs=n_envs)

########################## domain randomization ##########################
robot.set_friction_ratio(
friction_ratio=0.5 + torch.rand(scene.n_envs, robot.n_links),
link_indices=np.arange(0, robot.n_links),
)
robot.set_mass_shift(
mass_shift=-0.5 + torch.rand(scene.n_envs, robot.n_links),
link_indices=np.arange(0, robot.n_links),
)
robot.set_COM_shift(
com_shift=-0.05 + 0.1 * torch.rand(scene.n_envs, robot.n_links, 3),
link_indices=np.arange(0, robot.n_links),
)

joint_names = [
"FR_hip_joint",
"FR_thigh_joint",
"FR_calf_joint",
"FL_hip_joint",
"FL_thigh_joint",
"FL_calf_joint",
"RR_hip_joint",
"RR_thigh_joint",
"RR_calf_joint",
"RL_hip_joint",
"RL_thigh_joint",
"RL_calf_joint",
]
motor_dofs = [robot.get_joint(name).dof_idx_local for name in joint_names]

robot.set_dofs_kp(np.full(12, 20), motor_dofs)
robot.set_dofs_kv(np.full(12, 1), motor_dofs)
default_dof_pos = np.array(
[
0.0,
0.8,
-1.5,
0.0,
0.8,
-1.5,
0.0,
1.0,
-1.5,
0.0,
1.0,
-1.5,
]
)
robot.control_dofs_position(default_dof_pos, motor_dofs)

for i in range(1000):
scene.step()


if __name__ == "__main__":
main()
40 changes: 37 additions & 3 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,10 @@ def set_friction_ratio(self, friction_ratio, link_indices, envs_idx=None):
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
geom_indices = []
for i in link_indices:
for j in range(self._links[i].n_geoms):
geom_indices.append(self._links[i]._geom_start + j)
self._solver.set_geoms_friction_ratio(
torch.cat(
[
Expand All @@ -2190,9 +2194,7 @@ def set_friction_ratio(self, friction_ratio, link_indices, envs_idx=None):
],
dim=-1,
),
torch.tensor(
[[self._links[j]._geom_start + i for i in range(self._links[j].n_geoms)] for j in link_indices]
).view(-1),
geom_indices,
envs_idx,
)

Expand All @@ -2217,6 +2219,38 @@ def set_friction(self, friction):
for link in self._links:
link.set_friction(friction)

def set_mass_shift(self, mass_shift, link_indices, envs_idx=None):
"""
Set the mass shift of specified links.
Parameters
----------
mass : torch.Tensor, shape (n_envs, n_links)
The mass shift
link_indices : array_like
The indices of the links to set mass shift.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
for i in range(len(link_indices)):
link_indices[i] += self._link_start
self._solver.set_links_mass_shift(mass_shift, link_indices, envs_idx)

def set_COM_shift(self, com_shift, link_indices, envs_idx=None):
"""
Set the center of mass (COM) shift of specified links.
Parameters
----------
com : torch.Tensor, shape (n_envs, n_links, 3)
The COM shift
link_indices : array_like
The indices of the links to set COM shift.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
for i in range(len(link_indices)):
link_indices[i] += self._link_start
self._solver.set_links_COM_shift(com_shift, link_indices, envs_idx)

@gs.assert_built
def get_mass(self):
"""
Expand Down
5 changes: 5 additions & 0 deletions genesis/engine/solvers/rigid/collider_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,9 +991,14 @@ def _func_mpr(self, i_ga, i_gb, i_b):
i_gb, i_ga = i_ga, i_gb

is_plane = self._solver.geoms_info[i_ga].type == gs.GEOM_TYPE.PLANE

i_la = self._solver.geoms_info[i_ga].link_idx
i_lb = self._solver.geoms_info[i_gb].link_idx
is_self_pair = self._solver.links_info.root_idx[i_la] == self._solver.links_info.root_idx[i_lb]
multi_contact = (
self._solver.geoms_info[i_ga].type != gs.GEOM_TYPE.SPHERE
and self._solver.geoms_info[i_gb].type != gs.GEOM_TYPE.SPHERE
and not is_self_pair
)
if is_plane:
self._func_plane_contact(i_ga, i_gb, multi_contact, i_b)
Expand Down
12 changes: 2 additions & 10 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,7 @@ def add_collision_constraints(self):

if ti.static(self.sparse_solve):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
)
imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel)

diag = t + impact.friction * impact.friction * t
diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS)
Expand Down Expand Up @@ -199,12 +196,7 @@ def add_joint_limit_constraints(self):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
11 changes: 2 additions & 9 deletions genesis/engine/solvers/rigid/constraint_solver_decomp_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def add_collision_constraints(self, island, i_b):
if ti.static(self.sparse_solve):
self.jac_n_relevant_dofs[n_con, i_b] = con_n_relevant_dofs

imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time, impact.sol_params, -impact.penetration, jac_qvel
)
imp, aref = gu.imp_aref(impact.sol_params, -impact.penetration, jac_qvel)

diag = t + impact.friction * impact.friction * t
diag *= 2 * impact.friction * impact.friction * (1 - imp) / ti.max(imp, gs.EPS)
Expand Down Expand Up @@ -235,12 +233,7 @@ def add_joint_limit_constraints(self, island, i_b):

jac = side
jac_qvel = jac * self._solver.dofs_state[i_d, i_b].vel
imp, aref = gu.imp_aref(
self._solver._sol_contact_resolve_time,
self._solver.dofs_info[i_d].sol_params,
pos,
jac_qvel,
)
imp, aref = gu.imp_aref(self._solver.dofs_info[i_d].sol_params, pos, jac_qvel)
diag = self._solver.dofs_info[i_d].invweight * (pos < 0) * (1 - imp) / (imp + gs.EPS)
aref = aref * (pos < 0)
if pos < 0:
Expand Down
Loading

0 comments on commit 5eab75a

Please sign in to comment.