Skip to content

Commit

Permalink
Define jaxsim.typing.VelRepr as Int
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <[email protected]>
  • Loading branch information
flferretti and diegoferigo committed Jun 14, 2024
1 parent 5ba13af commit 88d5a2d
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
Base class for model data structures with velocity representation.
"""

velocity_representation: int = dataclasses.field(
velocity_representation: jtp.VelRepr = dataclasses.field(
default=VelRepr.Inertial, kw_only=True
)

@contextlib.contextmanager
def switch_velocity_representation(
self, velocity_representation: int
self, velocity_representation: jtp.VelRepr
) -> ContextManager[Self]:
"""
Context manager to temporarily switch the velocity representation.
Expand Down Expand Up @@ -83,7 +83,7 @@ def switch_velocity_representation(
@functools.partial(jax.jit, static_argnames=["is_force"])
def inertial_to_other_representation(
array: jtp.Array,
other_representation: int,
other_representation: jtp.VelRepr,
transform: jtp.Matrix,
*,
is_force: bool,
Expand Down Expand Up @@ -148,7 +148,7 @@ def to_mixed():
@functools.partial(jax.jit, static_argnames=["is_force"])
def other_representation_to_inertial(
array: jtp.Array,
other_representation: int,
other_representation: jtp.VelRepr,
transform: jtp.Matrix,
*,
is_force: bool,
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
@staticmethod
def zero(
model: js.model.JaxSimModel,
velocity_representation: int = VelRepr.Inertial,
velocity_representation: jtp.VelRepr = VelRepr.Inertial,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with zero state.
Expand Down Expand Up @@ -112,7 +112,7 @@ def build(
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
soft_contacts_state: js.ode_data.SoftContactsState | None = None,
soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
velocity_representation: int = VelRepr.Inertial,
velocity_representation: jtp.VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
) -> JaxSimModelData:
"""
Expand Down Expand Up @@ -631,7 +631,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
def reset_base_linear_velocity(
self,
linear_velocity: jtp.VectorLike,
velocity_representation: int | None = None,
velocity_representation: jtp.VelRepr | None = None,
) -> Self:
"""
Reset the base linear velocity.
Expand Down Expand Up @@ -659,7 +659,7 @@ def reset_base_linear_velocity(
def reset_base_angular_velocity(
self,
angular_velocity: jtp.VectorLike,
velocity_representation: int | None = None,
velocity_representation: jtp.VelRepr | None = None,
) -> Self:
"""
Reset the base angular velocity.
Expand Down Expand Up @@ -687,7 +687,7 @@ def reset_base_angular_velocity(
def reset_base_velocity(
self,
base_velocity: jtp.VectorLike,
velocity_representation: int | None = None,
velocity_representation: jtp.VelRepr | None = None,
) -> Self:
"""
Reset the base 6D velocity.
Expand Down Expand Up @@ -732,7 +732,7 @@ def random_model_data(
model: js.model.JaxSimModel,
*,
key: jax.Array | None = None,
velocity_representation: int | None = None,
velocity_representation: jtp.VelRepr | None = None,
base_pos_bounds: tuple[
jtp.FloatLike | Sequence[jtp.FloatLike],
jtp.FloatLike | Sequence[jtp.FloatLike],
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
@staticmethod
def zero(
model: js.model.JaxSimModel,
velocity_representation: int = VelRepr.Inertial,
velocity_representation: jtp.VelRepr = VelRepr.Inertial,
) -> JaxSimModelReferences:
"""
Create a `JaxSimModelReferences` object with zero references.
Expand All @@ -55,7 +55,7 @@ def build(
joint_force_references: jtp.Vector | None = None,
link_forces: jtp.Matrix | None = None,
data: js.data.JaxSimModelData | None = None,
velocity_representation: int | None = None,
velocity_representation: jtp.VelRepr | None = None,
) -> JaxSimModelReferences:
"""
Create a `JaxSimModelReferences` object with the given references.
Expand Down Expand Up @@ -225,7 +225,7 @@ def check_not_inertial() -> None:
false_fun=lambda: None,
)

def not_inertial(velocity_representation: int) -> jtp.Matrix:
def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix:
# Helper function to convert a single 6D force to the active representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
Expand Down Expand Up @@ -468,15 +468,15 @@ def check_not_inertial() -> None:
)

# If inertial-fixed representation, we can directly store the link forces.
def inertial(velocity_representation: int) -> JaxSimModelReferences:
def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences:
W_f_L = f_L
return replace(
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
W_f0_L + W_f_L
)
)

def not_inertial(velocity_representation: int) -> JaxSimModelReferences:
def not_inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences:
# Helper function to convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert_using_link_frame(
Expand Down
2 changes: 2 additions & 0 deletions src/jaxsim/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@
IntLike = Int
BoolLike = Bool
FloatLike = Float

VelRepr = Int
3 changes: 2 additions & 1 deletion tests/test_api_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import pytest

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import VelRepr

from . import utils_idyntree


def test_com_properties(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
3 changes: 2 additions & 1 deletion tests/test_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.utils import Mutability

Expand All @@ -21,7 +22,7 @@ def test_data_valid(

def test_data_joint_indexing(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
3 changes: 2 additions & 1 deletion tests/test_api_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_frame_transforms(

def test_frame_jacobians(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
7 changes: 4 additions & 3 deletions tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -117,7 +118,7 @@ def test_link_transforms(

def test_link_jacobians(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down Expand Up @@ -184,7 +185,7 @@ def test_link_jacobians(

def test_link_bias_acceleration(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down Expand Up @@ -216,7 +217,7 @@ def test_link_bias_acceleration(

def test_link_jacobian_derivative(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
7 changes: 4 additions & 3 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import VelRepr

from . import utils_idyntree
Expand Down Expand Up @@ -221,7 +222,7 @@ def test_model_creation_and_reduction(

def test_model_properties(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down Expand Up @@ -268,7 +269,7 @@ def test_model_properties(
def test_model_rbda(
jaxsim_models_types: js.model.JaxSimModel,
prng_key: jax.Array,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
):

model = jaxsim_models_types
Expand Down Expand Up @@ -479,7 +480,7 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

def test_model_fd_id_consistency(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def get_random_data_and_references(
model: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
key: jax.Array,
) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]:

Expand Down
3 changes: 2 additions & 1 deletion tests/test_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import pytest

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import VelRepr


def test_collidable_point_jacobians(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
prng_key: jax.Array,
):

Expand Down
5 changes: 3 additions & 2 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import jaxsim.api as js
import jaxsim.integrators
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr


def test_box_with_external_forces(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: int,
velocity_representation: jtp.VelRepr,
):
"""
This test simulates a box falling due to gravity.
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_box_with_external_forces(

def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
velocity_representation: jtp.VelRepr,
prng_key: jnp.ndarray,
):

Expand Down

0 comments on commit 88d5a2d

Please sign in to comment.