"""
This script shows how we evaluated a finetuned Octo model on a real WidowX robot. While the exact specifics may not
be applicable to your use case, this script serves as a didactic example of how to use Octo in a real-world setting.
If you wish, you may reproduce these results by [reproducing the robot setup](https://rail-berkeley.github.io/bridgedata/)
and installing [the robot controller](https://github.com/rail-berkeley/bridge_data_robot)
"""
from datetime import datetime
from functools import partial
import os
import time

from absl import app, flags, logging
import click
import cv2
from envs.widowx_env import convert_obs, state_to_eep, wait_for_obs, WidowXGym
import imageio
import jax
import jax.numpy as jnp
import numpy as np
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, TemporalEnsembleWrapper
from octo.utils.train_callbacks import supply_rng

np.set_printoptions(suppress=True)

logging.set_verbosity(logging.WARNING)
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "checkpoint_weights_path", None, "Path to checkpoint", required=True
)
flags.DEFINE_integer("checkpoint_step", None, "Checkpoint step", required=True)

# custom to bridge_data_robot
flags.DEFINE_string("ip", "localhost", "IP address of the robot")
flags.DEFINE_integer("port", 5556, "Port of the robot")
flags.DEFINE_spaceseplist("goal_eep", [0.3, 0.0, 0.15], "Goal position")
flags.DEFINE_spaceseplist("initial_eep", [0.3, 0.0, 0.15], "Initial position")
flags.DEFINE_bool("blocking", False, "Use the blocking controller")

flags.DEFINE_integer("im_size", None, "Image size", required=True)
flags.DEFINE_string("video_save_path", None, "Path to save video")
flags.DEFINE_integer("num_timesteps", 120, "Number of rollout timesteps")
flags.DEFINE_integer("window_size", 2, "Observation history length")
flags.DEFINE_integer(
    "action_horizon", 4, "Length of action sequence to execute/ensemble"
)

# show image flag
flags.DEFINE_bool("show_image", False, "Show image")
##############################################################################
STEP_DURATION_MESSAGE = """
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with
blocking control and we evaluate with blocking control.
Be sure to use a step duration of 0.2 if evaluating with non-blocking control.
"""
STEP_DURATION = 0.2
STICKY_GRIPPER_NUM_STEPS = 1
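# The bounds below are the [low, high] corners of the allowed end-effector
# workspace; the first three entries of each corner are xyz position in meters
# (our reading of the bridge_data_robot convention, not verified here).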
WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
CAMERA_TOPICS = [{"name": "/blue/image_raw"}]
ENV_PARAMS = {
    "camera_topics": CAMERA_TOPICS,
    "override_workspace_boundaries": WORKSPACE_BOUNDS,
    "move_duration": STEP_DURATION,
}
##############################################################################


def main(_):
    # set up the widowx client
    if FLAGS.initial_eep is not None:
        assert isinstance(FLAGS.initial_eep, list)
        initial_eep = [float(e) for e in FLAGS.initial_eep]
        start_state = np.concatenate([initial_eep, [0, 0, 0, 1]])
    else:
        start_state = None

    env_params = WidowXConfigs.DefaultEnvParams.copy()
    env_params.update(ENV_PARAMS)
    # guard against the no-initial_eep case, where start_state is None
    if start_state is not None:
        env_params["start_state"] = list(start_state)
    widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port)
    widowx_client.init(env_params, image_size=FLAGS.im_size)
    env = WidowXGym(
        widowx_client, FLAGS.im_size, FLAGS.blocking, STICKY_GRIPPER_NUM_STEPS
    )
    if not FLAGS.blocking:
        assert STEP_DURATION == 0.2, STEP_DURATION_MESSAGE
    # load models
    model = OctoModel.load_pretrained(
        FLAGS.checkpoint_weights_path,
        FLAGS.checkpoint_step,
    )
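    # OctoModel.load_pretrained restores the model config and weights saved in
    # the checkpoint directory at the given step, e.g. a checkpoint produced by
    # the Octo finetuning example.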
    # wrap the robot environment
    env = HistoryWrapper(env, FLAGS.window_size)
    env = TemporalEnsembleWrapper(env, FLAGS.action_horizon)
    # switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
    # env = RHCWrapper(env, FLAGS.action_horizon)
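    # In brief: HistoryWrapper stacks the last `window_size` observations so
    # the policy sees a short context, TemporalEnsembleWrapper queries the
    # policy every step and averages the overlapping `action_horizon`-length
    # action chunks, and RHCWrapper would instead execute each predicted chunk
    # open-loop before replanning.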
    # create policy functions
    def sample_actions(
        pretrained_model: OctoModel,
        observations,
        tasks,
        rng,
    ):
        # add batch dim to observations
        observations = jax.tree_map(lambda x: x[None], observations)
        actions = pretrained_model.sample_actions(
            observations,
            tasks,
            rng=rng,
            unnormalization_statistics=pretrained_model.dataset_statistics[
                "bridge_dataset"
            ]["action"],
        )
        # remove batch dim
        return actions[0]
    policy_fn = supply_rng(
        partial(
            sample_actions,
            model,
        )
    )
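    # supply_rng wraps the policy so each call receives a freshly split JAX
    # PRNG key via the `rng` argument, making repeated sampling stochastic.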
    goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8)
    goal_instruction = ""
    # goal sampling loop
    while True:
        modality = click.prompt(
            "Language or goal image?", type=click.Choice(["l", "g"])
        )

        if modality == "g":
            if click.confirm("Take a new goal?", default=True):
                assert isinstance(FLAGS.goal_eep, list)
                _eep = [float(e) for e in FLAGS.goal_eep]
                goal_eep = state_to_eep(_eep, 0)
                widowx_client.move_gripper(1.0)  # open gripper

                move_status = None
                while move_status != WidowXStatus.SUCCESS:
                    move_status = widowx_client.move(goal_eep, duration=1.5)

                input("Press [Enter] when ready for taking the goal image. ")
                obs = wait_for_obs(widowx_client)
                obs = convert_obs(obs, FLAGS.im_size)
                goal = jax.tree_map(lambda x: x[None], obs)

            # Format task for the model
            task = model.create_tasks(goals=goal)
            # For logging purposes
            goal_image = goal["image_primary"][0]
            goal_instruction = ""
        elif modality == "l":
            print("Current instruction: ", goal_instruction)
            if click.confirm("Take a new instruction?", default=True):
                text = input("Instruction?")
            # Format task for the model
            task = model.create_tasks(texts=[text])
            # For logging purposes
            goal_instruction = text
            goal_image = jnp.zeros_like(goal_image)
        else:
            raise NotImplementedError()
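        # Note: the branches above assume a new goal image / instruction is
        # captured on the first pass through this loop; declining the prompt
        # reuses `goal` / `text` from a previous iteration.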
input("Press [Enter] to start.")
# reset env
obs, _ = env.reset()
time.sleep(2.0)
# do rollout
last_tstep = time.time()
images = []
goals = []
t = 0
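        # rate-limit the control loop to one policy query per STEP_DURATION seconds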
        while t < FLAGS.num_timesteps:
            if time.time() > last_tstep + STEP_DURATION:
                last_tstep = time.time()

                # save images
                images.append(obs["image_primary"][-1])
                goals.append(goal_image)

                if FLAGS.show_image:
                    bgr_img = cv2.cvtColor(
                        obs["image_primary"][-1], cv2.COLOR_RGB2BGR
                    )
                    cv2.imshow("img_view", bgr_img)
                    cv2.waitKey(20)

                # get action
                forward_pass_time = time.time()
                action = np.array(policy_fn(obs, task), dtype=np.float64)
                print("forward pass time: ", time.time() - forward_pass_time)

                # perform environment step
                start_time = time.time()
                obs, _, _, truncated, _ = env.step(action)
                print("step time: ", time.time() - start_time)

                t += 1

                if truncated:
                    break
        # save video
        if FLAGS.video_save_path is not None:
            os.makedirs(FLAGS.video_save_path, exist_ok=True)
            curr_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            save_path = os.path.join(
                FLAGS.video_save_path,
                f"{curr_time}.mp4",
            )
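            # each frame stacks the goal image on top of the live camera view;
            # fps = 3 / STEP_DURATION plays back at roughly 3x real time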
            video = np.concatenate([np.stack(goals), np.stack(images)], axis=1)
            imageio.mimsave(save_path, video, fps=1.0 / STEP_DURATION * 3)
if __name__ == "__main__":
    app.run(main)