Skip to content

Commit

Permalink
test all examples; fix empty dofs; fix minor
Browse files Browse the repository at this point in the history
  • Loading branch information
zswang666 committed Dec 22, 2024
1 parent e008185 commit 87c53a8
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions genesis/engine/solvers/rigid/rigid_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,21 @@ def _init_dof_fields(self):
)

joints = self.joints
self._kernel_init_dof_fields(
dofs_motion_ang=np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float),
dofs_motion_vel=np.concatenate([joint.dofs_motion_vel for joint in joints], dtype=gs.np_float),
dofs_limit=np.concatenate([joint.dofs_limit for joint in joints], dtype=gs.np_float),
dofs_invweight=np.concatenate([joint.dofs_invweight for joint in joints], dtype=gs.np_float),
dofs_stiffness=np.concatenate([joint.dofs_stiffness for joint in joints], dtype=gs.np_float),
dofs_sol_params=np.concatenate([joint.dofs_sol_params for joint in joints], dtype=gs.np_float),
dofs_damping=np.concatenate([joint.dofs_damping for joint in joints], dtype=gs.np_float),
dofs_armature=np.concatenate([joint.dofs_armature for joint in joints], dtype=gs.np_float),
dofs_kp=np.concatenate([joint.dofs_kp for joint in joints], dtype=gs.np_float),
dofs_kv=np.concatenate([joint.dofs_kv for joint in joints], dtype=gs.np_float),
dofs_force_range=np.concatenate([joint.dofs_force_range for joint in joints], dtype=gs.np_float),
)
is_nonempty = np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float).shape[0] > 0
if is_nonempty: # handle the case where there is a link with no dofs -- otherwise may cause invalid memory
self._kernel_init_dof_fields(
dofs_motion_ang=np.concatenate([joint.dofs_motion_ang for joint in joints], dtype=gs.np_float),
dofs_motion_vel=np.concatenate([joint.dofs_motion_vel for joint in joints], dtype=gs.np_float),
dofs_limit=np.concatenate([joint.dofs_limit for joint in joints], dtype=gs.np_float),
dofs_invweight=np.concatenate([joint.dofs_invweight for joint in joints], dtype=gs.np_float),
dofs_stiffness=np.concatenate([joint.dofs_stiffness for joint in joints], dtype=gs.np_float),
dofs_sol_params=np.concatenate([joint.dofs_sol_params for joint in joints], dtype=gs.np_float),
dofs_damping=np.concatenate([joint.dofs_damping for joint in joints], dtype=gs.np_float),
dofs_armature=np.concatenate([joint.dofs_armature for joint in joints], dtype=gs.np_float),
dofs_kp=np.concatenate([joint.dofs_kp for joint in joints], dtype=gs.np_float),
dofs_kv=np.concatenate([joint.dofs_kv for joint in joints], dtype=gs.np_float),
dofs_force_range=np.concatenate([joint.dofs_force_range for joint in joints], dtype=gs.np_float),
)

# just in case
self.dofs_state.force.fill(0)
Expand Down Expand Up @@ -3281,7 +3283,7 @@ def _kernel_set_links_pos(
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]):
i_l = links_idx[i_l_]
I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l
I_l = [i_l, i_b_] if ti.static(self._options.batch_links_info) else i_l
if self.links_info[I_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos
for i in ti.static(range(3)):
self.links_state[i_l, envs_idx[i_b_]].pos[i] = pos[i_b_, i_l_, i]
Expand Down Expand Up @@ -3315,7 +3317,7 @@ def _kernel_set_links_quat(
ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.PARTIAL)
for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]):
i_l = links_idx[i_l_]
I_l = [i_l, i_b] if ti.static(self._options.batch_links_info) else i_l
I_l = [i_l, i_b_] if ti.static(self._options.batch_links_info) else i_l
if self.links_info[I_l].is_fixed: # change links_state directly as the link's pose is not contained in qpos
for i in ti.static(range(4)):
self.links_state[i_l, envs_idx[i_b_]].quat[i] = quat[i_b_, i_l_, i]
Expand Down

0 comments on commit 87c53a8

Please sign in to comment.