Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revive action primitives examples and tests #842

Merged
merged 87 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
421ca70
load robot bugfix
Aug 21, 2024
e7ef748
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
495f49f
add the comment back
Aug 21, 2024
8da4ee3
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Aug 21, 2024
609da8e
action primitive navigation fixed
Aug 26, 2024
2a70755
Merge branch 'og-develop' of https://github.com/StanfordVL/OmniGibson…
Aug 26, 2024
6db6e88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2024
f7601e5
ik controller to joint controller
Aug 27, 2024
0544275
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Aug 27, 2024
1ac8497
ik control to joint control
Aug 27, 2024
51d0cf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
aa58d9b
remove empty lines
Aug 27, 2024
bf24672
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Aug 27, 2024
53494c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
6db2377
test functional 7/10
Sep 1, 2024
2b3278d
test 7/10 functional
Sep 1, 2024
a155e70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2024
6b3c8ff
Merge branch 'og-develop' of https://github.com/StanfordVL/OmniGibson…
Sep 4, 2024
66970e2
primitives partially working
Sep 7, 2024
a5fdd04
pritives test runnable
Sep 10, 2024
782f759
symbolic primitives partially passed
Sep 11, 2024
254deb3
test all working
Sep 12, 2024
d685a57
test_tiago w/ transition_rule
Sep 13, 2024
61b629b
wip symbolic test
Sep 16, 2024
e54d544
most fixes addressed
Sep 16, 2024
d6263ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2024
c27f59e
two empty lines
Sep 16, 2024
e66c82d
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 16, 2024
d2bfecd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2024
53277f4
symbolic primitive and primitive passing
Sep 19, 2024
2cf3902
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 19, 2024
7dcdfd2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
8fe8ecf
control no op action
Sep 20, 2024
4f94198
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 20, 2024
35cbf69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
7af7086
Merge branch 'og-develop' of https://github.com/StanfordVL/OmniGibson…
Sep 20, 2024
db7b0bd
test fixed
Sep 21, 2024
3d3d8dc
test passed
Sep 21, 2024
1e98daa
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 21, 2024
9bbeb99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2024
d91ea90
holonomic base wip
Sep 23, 2024
113ffe5
revived primitive and symbolic primitives
Sep 23, 2024
a9fdae5
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 23, 2024
6e5d186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
6cea6d6
reformat empty action
Sep 23, 2024
3a806cf
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 23, 2024
83c1eea
reformat for github
Sep 23, 2024
b662788
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
2961164
empty action for no-op
Sep 24, 2024
ae49f5b
true no op action
Sep 24, 2024
527732c
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 24, 2024
ebdd758
Merge branch 'og-develop' of https://github.com/StanfordVL/OmniGibson…
Sep 25, 2024
e41a766
primitives fixed except for symbolic primitives
Sep 25, 2024
9744531
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 25, 2024
ca1b56e
test
Sep 26, 2024
9680210
primitives minor update
Sep 26, 2024
36b896a
minor primitives test update
Sep 26, 2024
7e05790
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
7c29f5a
all tests passed
Sep 26, 2024
f816bb0
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 26, 2024
809be5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
5425627
ready for merge
Sep 27, 2024
126254e
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 27, 2024
8221732
test error resovled from IK controller
Sep 27, 2024
c583e43
multi finger gripper support added
Sep 27, 2024
f84d324
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
3f661e4
wip test controllers
Sep 30, 2024
22df2cd
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Sep 30, 2024
2f0734e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
0a4dc08
removed finger path
Sep 30, 2024
6e9f6f1
test without seeding
Sep 30, 2024
d145657
primitive test without seednig
Sep 30, 2024
aa1ec79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
06c8bfc
test without seeding
Sep 30, 2024
97893eb
primitive test without seed
Sep 30, 2024
85a849e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
b5c254b
primitive with seeding
Oct 1, 2024
5f754ba
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Oct 1, 2024
0b245a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
79e0105
test with seeding
Oct 1, 2024
df38c80
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Oct 1, 2024
cc32e4e
Update usd_utils.py
yyf20001230 Oct 1, 2024
1954ad4
test controllers with relaxed bounds
Oct 1, 2024
eabc9f2
Merge branch 'curobo' of https://github.com/StanfordVL/OmniGibson int…
Oct 1, 2024
3232a45
Merge branch 'og-develop' into curobo
yyf20001230 Oct 1, 2024
893ad2c
revert back the tolerance relaxation, and give robots deterministic n…
ChengshuLi Oct 1, 2024
e88d4a4
Merge branch 'og-develop' into curobo
cgokmen Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 64 additions & 37 deletions omnigibson/action_primitives/starter_semantic_action_primitives.py
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
cgokmen marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
m.JOINT_POS_DIFF_THRESHOLD = 0.01
m.JOINT_CONTROL_MIN_ACTION = 0.0
m.MAX_ALLOWED_JOINT_ERROR_FOR_LINEAR_MOTION = math.radians(45)
m.TIME_WITHOUT_CHECKING = 1.0

log = create_module_logger(module_name=__name__)

Expand All @@ -116,8 +117,8 @@ def __init__(self):
self.relative_poses = {}
self.links_relative_poses = {}
self.reset_pose = {
"original": ([0, 0, -5.0], [0, 0, 0, 1]),
"simplified": ([5, 0, -5.0], [0, 0, 0, 1]),
"original": (th.tensor([0, 0, -5.0], dtype=th.float32), th.tensor([0, 0, 0, 1], dtype=th.float32)),
"simplified": (th.tensor([5, 0, -5.0], dtype=th.float32), th.tensor([0, 0, 0, 1], dtype=th.float32)),
}


Expand Down Expand Up @@ -164,14 +165,17 @@ def _assemble_robot_copy(self):
if m.TIAGO_TORSO_FIXED:
assert self.arm == "left", "Fixed torso mode only supports left arm!"
joint_control_idx = self.robot.arm_control_idx["left"]
joint_pos = th.tensor(self.robot.get_joint_positions()[joint_control_idx])
joint_pos = th.tensor(self.robot.get_joint_positions()[joint_control_idx], dtype=th.float32)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
else:
joint_combined_idx = th.cat([self.robot.trunk_control_idx, self.robot.arm_control_idx[fk_descriptor]])
joint_pos = th.tensor(self.robot.get_joint_positions()[joint_combined_idx])
joint_pos = th.tensor(self.robot.get_joint_positions()[joint_combined_idx], dtype=th.float32)
link_poses = self.fk_solver.get_link_poses(joint_pos, arm_links)

# Set position of robot copy root prim
self._set_prim_pose(self.robot_copy.prims[self.robot_copy_type], self.robot.get_position_orientation())
# Set position of robot copy root prim as a tensor tuple
pos, orn = self.robot.get_position_orientation()
pos = th.tensor(pos, dtype=th.float32)
orn = th.tensor(orn, dtype=th.float32)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
self._set_prim_pose(self.robot_copy.prims[self.robot_copy_type], (pos, orn))

# Assemble robot meshes
for link_name, meshes in self.robot_copy.meshes[self.robot_copy_type].items():
Expand All @@ -191,9 +195,9 @@ def _assemble_robot_copy(self):
self._set_prim_pose(copy_mesh, mesh_copy_pose)

def _set_prim_pose(self, prim, pose):
translation = lazy.pxr.Gf.Vec3d(*(th.tensor(pose[0], dtype=th.float32).tolist()))
translation = lazy.pxr.Gf.Vec3d(*pose[0].tolist())
prim.GetAttribute("xformOp:translate").Set(translation)
orientation = th.tensor(pose[1], dtype=th.float32)[[3, 0, 1, 2]]
orientation = pose[1][[3, 0, 1, 2]]
prim.GetAttribute("xformOp:orient").Set(lazy.pxr.Gf.Quatd(*orientation.tolist()))

def _construct_disabled_collision_pairs(self):
Expand Down Expand Up @@ -237,7 +241,7 @@ def _construct_disabled_collision_pairs(self):

# Disable original robot colliders so copy can't collide with it
disabled_colliders += [link.prim_path for link in self.robot.links.values()]
filter_categories = ["floors"]
filter_categories = ["floors", "carpet"]
for obj in self.env.scene.objects:
if obj.category in filter_categories:
disabled_colliders += [link.prim_path for link in obj.links.values()]
Expand Down Expand Up @@ -376,9 +380,9 @@ def _load_robot_copy(self):
lazy.omni.usd.commands.CreatePrimCommand("Xform", rc["copy_path"]).do()
copy_robot = lazy.omni.isaac.core.utils.prims.get_prim_at_path(rc["copy_path"])
reset_pose = robot_copy.reset_pose[robot_type]
translation = lazy.pxr.Gf.Vec3d(*th.tensor(reset_pose[0], dtype=th.float32).tolist())
translation = lazy.pxr.Gf.Vec3d(*reset_pose[0].tolist())
copy_robot.GetAttribute("xformOp:translate").Set(translation)
orientation = th.tensor(reset_pose[1], dtype=th.float32)[[3, 0, 1, 2]]
orientation = reset_pose[1][[3, 0, 1, 2]]
copy_robot.GetAttribute("xformOp:orient").Set(lazy.pxr.Gf.Quatd(*orientation.tolist()))

robot_to_copy = None
Expand Down Expand Up @@ -571,7 +575,6 @@ def _open_or_close(self, obj, should_open):

# If the grasp pose is too far, navigate
yield from self._navigate_if_needed(obj, pose_on_obj=grasp_pose)

yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
yield from self._move_hand(grasp_pose, stop_if_stuck=True)

# We can pre-grasp in sticky grasping mode only for opening
Expand Down Expand Up @@ -824,14 +827,12 @@ def _place_with_predicate(self, obj, predicate):
"""
# Update the tracking to track the object.
self._tracking_object = obj

yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
obj_in_hand = self._get_obj_in_hand()
if obj_in_hand is None:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.PRE_CONDITION_ERROR,
"You need to be grasping an object first to place it somewhere.",
)

yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
# Sample location to place object
obj_pose = self._sample_pose_with_object_and_predicate(predicate, obj_in_hand, obj)
hand_pose = self._get_hand_pose_for_object_pose(obj_pose)
Expand Down Expand Up @@ -991,17 +992,16 @@ def _move_hand_joint(self, joint_pos):
torso_fixed=m.TIAGO_TORSO_FIXED,
)

# plan = self._add_linearly_interpolated_waypoints(plan, 0.1)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
if plan is None:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.PLANNING_ERROR,
"There is no accessible path from where you are to the desired joint position. Try again",
)

# Follow the plan to navigate.
indented_print("Plan has %d steps", len(plan))
indented_print(f"Plan has {len(plan)} steps")
for i, joint_pos in enumerate(plan):
indented_print("Executing grasp plan step %d/%d", i + 1, len(plan))
indented_print(f"Executing arm movement plan step {i + 1}/{len(plan)}")
yield from self._move_hand_direct_joint(joint_pos, ignore_failure=True)

def _move_hand_ik(self, eef_pose, stop_if_stuck=False):
Expand All @@ -1026,19 +1026,18 @@ def _move_hand_ik(self, eef_pose, stop_if_stuck=False):
torso_fixed=m.TIAGO_TORSO_FIXED,
)

# plan = self._add_linearly_interpolated_waypoints(plan, 0.1)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
if plan is None:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.PLANNING_ERROR,
"There is no accessible path from where you are to the desired joint position. Try again",
)

# Follow the plan to navigate.
indented_print("Plan has %d steps", len(plan))
indented_print(f"Plan has {len(plan)} steps")
for i, target_pose in enumerate(plan):
target_pos = target_pose[:3]
target_quat = T.axisangle2quat(target_pose[3:])
indented_print("Executing grasp plan step %d/%d", i + 1, len(plan))
indented_print(f"Executing grasp plan step {i + 1}/{len(plan)}")
yield from self._move_hand_direct_ik(
(target_pos, target_quat), ignore_failure=True, in_world_frame=False, stop_if_stuck=stop_if_stuck
)
Expand Down Expand Up @@ -1081,22 +1080,23 @@ def _move_hand_direct_joint(self, joint_pos, stop_on_contact=False, ignore_failu
# Store the previous eef pose for checking if we got stuck
prev_eef_pos = th.zeros(3)

for _ in range(m.MAX_STEPS_FOR_HAND_MOVE_JOINT):
for i in range(m.MAX_STEPS_FOR_HAND_MOVE_JOINT):
current_joint_pos = self.robot.get_joint_positions()[self._manipulation_control_idx]
diff_joint_pos = th.tensor(joint_pos) - th.tensor(current_joint_pos)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
if th.max(th.abs(diff_joint_pos)).item() < m.JOINT_POS_DIFF_THRESHOLD:
return
if stop_on_contact and detect_robot_collision_in_sim(self.robot, ignore_obj_in_hand=False):
return
if th.max(th.abs(self.robot.get_eef_position(self.arm) - prev_eef_pos)).item() < 0.0001:
# check if the eef stayed in the same pose for sufficiently long
if og.sim.get_sim_step_dt() * i > m.TIME_WITHOUT_CHECKING and th.max(th.abs(self.robot.get_eef_position(self.arm) - prev_eef_pos)).item() < 0.0001:
# We're stuck!
break

action = self._empty_action()
if use_delta:
action[self.robot.controller_action_idx[controller_name]] = diff_joint_pos
action[self.robot.controller_action_idx[controller_name]] = th.tensor(diff_joint_pos, dtype=th.float32)
else:
action[self.robot.controller_action_idx[controller_name]] = joint_pos
action[self.robot.controller_action_idx[controller_name]] = th.tensor(joint_pos, dtype=th.float32)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved

prev_eef_pos = self.robot.get_eef_position(self.arm)
yield self._postprocess_action(action)
Expand Down Expand Up @@ -1205,8 +1205,10 @@ def _move_hand_linearly_cartesian(
# into 1cm-long pieces
start_pos, start_orn = self.robot.eef_links[self.arm].get_position_orientation()
travel_distance = th.norm(target_pose[0] - start_pos)
num_poses = th.max([2, int(travel_distance / m.MAX_CARTESIAN_HAND_STEP) + 1]).item()
pos_waypoints = th.linspace(start_pos, target_pose[0], num_poses)
num_poses = int(
th.max(th.tensor([2, int(travel_distance / m.MAX_CARTESIAN_HAND_STEP) + 1], dtype=th.float32)).item()
)
pos_waypoints = self.linspace_1d_tensor(start_pos, target_pose[0], num_poses)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved

# Also interpolate the rotations
t_values = th.linspace(0, 1, num_poses)
Expand Down Expand Up @@ -1236,9 +1238,9 @@ def _move_hand_linearly_cartesian(

# Also decide if we can stop early.
current_pos, current_orn = self.robot.eef_links[self.arm].get_position_orientation()
pos_diff = th.norm(th.tensor(current_pos) - th.tensor(target_pose[0]))
pos_diff = th.norm(th.tensor(current_pos - target_pose[0], dtype=th.float32))
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
orn_diff = T.get_orientation_diff_in_radian(target_pose[1], current_orn).item()
if pos_diff < 0.005 and orn_diff < th.deg2rad(th.tensor([0.1])).item():
if pos_diff < 0.002 and orn_diff < th.deg2rad(th.tensor([0.1])).item():
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
return

if stop_on_contact and detect_robot_collision_in_sim(self.robot, ignore_obj_in_hand=False):
Expand Down Expand Up @@ -1432,6 +1434,10 @@ def _empty_action(self):
action_idx = self.robot.controller_action_idx[name]
no_op_goal = controller.compute_no_op_goal(self.robot.get_control_dict())

# TODO: wip solution for goal pose not consistent with action pose. if the controller uses delta motion, convert the no_op to all zeros
if self.robot._controller_config[name].get("use_delta_commands", True):
no_op_goal = {key: th.zeros_like(value) for key, value in no_op_goal.items()}
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved

if self.robot._controller_config[name]["name"] == "InverseKinematicsController":
assert (
self.robot._controller_config["arm_" + self.arm]["mode"] == "pose_absolute_ori"
Expand Down Expand Up @@ -1562,11 +1568,11 @@ def _navigate_to_pose(self, pose_2d):
"Could not make a navigation plan to get to the target position",
)

# #Follow the plan to navigate.
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
# self._draw_plan(plan)
# Follow the plan to navigate.
indented_print("Plan has %d steps", len(plan))
indented_print(f"Navigation plan has {len(plan)} steps")
for i, pose_2d in enumerate(plan):
indented_print("Executing navigation plan step %d/%d", i + 1, len(plan))
indented_print(f"Executing navigation plan step {i + 1}/{len(plan)}")
low_precision = True if i < len(plan) - 1 else False
yield from self._navigate_to_pose_direct(pose_2d, low_precision=low_precision)

Expand Down Expand Up @@ -1761,7 +1767,7 @@ def _sample_pose_near_object(self, obj, pose_on_obj=None, **kwargs):
distance_lo, distance_hi = 0.0, 5.0
distance = (th.rand(1) * (distance_hi - distance_lo) + distance_lo).item()
yaw_lo, yaw_hi = -math.pi, math.pi
yaw = (th.rand(1) * (yaw_hi - yaw_lo) + yaw_lo).item()
yaw = th.tensor((th.rand(1) * (yaw_hi - yaw_lo) + yaw_lo).item(), dtype=th.float32)
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
avg_arm_workspace_range = th.mean(self.robot.arm_workspace_range[self.arm])
pose_2d = th.tensor(
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
[
Expand All @@ -1783,6 +1789,7 @@ def _sample_pose_near_object(self, obj, pose_on_obj=None, **kwargs):
if not self._test_pose(pose_2d, context, pose_on_obj=pose_on_obj, **kwargs):
continue

indented_print("Found valid position near object.")
return pose_2d

raise ActionPrimitiveError(
Expand Down Expand Up @@ -1871,7 +1878,7 @@ def _sample_pose_with_object_and_predicate(

# Get the object pose by subtracting the offset
sampled_obj_pose = T.pose2mat((sampled_bb_center, sampled_bb_orn)) @ T.pose_inv(
T.pose2mat((bb_center_in_base, [0, 0, 0, 1]))
T.pose2mat((bb_center_in_base, th.tensor([0, 0, 0, 1], dtype=th.float32)))
)

# Check that the pose is near one of the poses in the near_poses list if provided.
Expand Down Expand Up @@ -1923,11 +1930,10 @@ def _get_robot_pose_from_2d_pose(pose_2d):
pose_2d (Iterable): (x, y, yaw) 2d pose

Returns:
2-tuple:
- 3-array: (x,y,z) Position in the world frame
- 4-array: (x,y,z,w) Quaternion orientation in the world frame
th.tensor: (x,y,z) Position in the world frame
th.tensor: (x,y,z,w) Quaternion orientation in the world frame
"""
pos = th.tensor([pose_2d[0], pose_2d[1], m.DEFAULT_BODY_OFFSET_FROM_FLOOR])
pos = th.tensor([pose_2d[0], pose_2d[1], m.DEFAULT_BODY_OFFSET_FROM_FLOOR], dtype=th.float32)
orn = T.euler2quat(th.tensor([0, 0, pose_2d[2]], dtype=th.float32))
return pos, orn

Expand Down Expand Up @@ -1991,3 +1997,24 @@ def _settle_robot(self):
break
empty_action = self._empty_action()
yield self._postprocess_action(empty_action)

def linspace_1d_tensor(self, start_pos, target_pose, num_poses):
"""
Create evenly spaced samples between two 1D tensors.

:param start_pos: Starting 1D tensor
:param target_pose: Ending 1D tensor
:param num_poses: Number of poses (samples) to generate
:return: Tensor of shape (num_poses, dim) where dim is the dimension of input tensors
"""
# Ensure inputs are 1D tensors
assert start_pos.dim() == 1 and target_pose.dim() == 1, "Input tensors must be 1D"
assert start_pos.shape == target_pose.shape, "Input tensors must have the same shape"

# Create a tensor of interpolation factors
t = th.linspace(0, 1, num_poses)

# Perform the interpolation
interpolated_points = start_pos.unsqueeze(0) + (target_pose - start_pos).unsqueeze(0) * t.unsqueeze(1)

return interpolated_points
yyf20001230 marked this conversation as resolved.
Show resolved Hide resolved
Loading