-
Notifications
You must be signed in to change notification settings - Fork 43
/
rdt_runner.py
250 lines (216 loc) · 10.7 KB
/
rdt_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import re
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
DPMSolverMultistepScheduler
from models.hub_mixin import CompatiblePyTorchModelHubMixin
from models.rdt.model import RDT
class RDTRunner(
        nn.Module,
        CompatiblePyTorchModelHubMixin,
        repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):
    """Diffusion policy wrapper around the RDT backbone.

    Bundles the RDT denoising transformer, three condition adaptors
    (language, image, state/action) that project raw tokens to the
    backbone's hidden size, a ``DDPMScheduler`` used to noise actions during
    training, and a ``DPMSolverMultistepScheduler`` used for sampling.
    """

    def __init__(self, *, action_dim, pred_horizon, config,
                 lang_token_dim, img_token_dim, state_token_dim,
                 max_lang_cond_len, img_cond_len, lang_pos_embed_config=None,
                 img_pos_embed_config=None, dtype=torch.bfloat16):
        """Build the backbone, the condition adaptors and the schedulers.

        Args:
            action_dim: number of action dimensions predicted per step.
            pred_horizon: number of action steps predicted at once.
            config: nested dict. This constructor reads
                ``config['rdt']`` (``hidden_size``, ``depth``, ``num_heads``),
                ``config['lang_adaptor' | 'img_adaptor' | 'state_adaptor']``
                (projector type strings, see ``build_condition_adapter``) and
                ``config['noise_scheduler']`` (``num_train_timesteps``,
                ``beta_schedule``, ``prediction_type``, ``clip_sample``,
                ``num_inference_timesteps``).
            lang_token_dim, img_token_dim, state_token_dim: feature widths of
                the raw language, image and state tokens.
            max_lang_cond_len, img_cond_len, lang_pos_embed_config,
            img_pos_embed_config, dtype: forwarded unchanged to ``RDT``.
        """
        super().__init__()
        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adaptors for the various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(
            config['lang_adaptor'],
            in_features=lang_token_dim,
            out_features=hidden_size
        )
        self.img_adaptor = self.build_condition_adapter(
            config['img_adaptor'],
            in_features=img_token_dim,
            out_features=hidden_size
        )
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,    # state + state mask (indicator)
            out_features=hidden_size
        )

        # Create the noise schedulers: DDPM for training, DPM-Solver for sampling
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        n_params = sum(
            p.numel()
            for module in (self.model, self.lang_adaptor,
                           self.img_adaptor, self.state_adaptor)
            for p in module.parameters()
        )
        print("Diffusion params: %e" % n_params)

    def build_condition_adapter(
            self, projector_type, in_features, out_features):
        """Build a projector mapping ``in_features`` to ``out_features``.

        Supported ``projector_type`` values:
            * ``'linear'``: a single ``nn.Linear``.
            * ``'mlp{N}x_gelu'``: one ``nn.Linear(in_features, out_features)``
              followed by N - 1 pairs of (tanh-approximated GELU,
              ``nn.Linear(out_features, out_features)``); no pairs are added
              when N <= 1.

        Raises:
            ValueError: if ``projector_type`` matches neither form.
        """
        if projector_type == 'linear':
            return nn.Linear(in_features, out_features)

        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(in_features, out_features)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU(approximate="tanh"))
            modules.append(nn.Linear(out_features, out_features))
        return nn.Sequential(*modules)

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)

        return: adapted (..., hidden_size) tensors for all input tokens,
            in the order (lang, img, state)
        '''
        adapted_lang = self.lang_adaptor(lang_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)
        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond,
                           state_traj, action_mask, ctrl_freqs):
        '''
        Run the sampling scheduler's reverse process starting from Gaussian
        noise and return the denoised action chunk.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), multiplied by action_mask
            so invalid action dimensions are zero.

        Note: calls set_timesteps on self.noise_scheduler_sample, so the
        sampling scheduler's state is reset on every call.
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        noisy_action = torch.randn(
            size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
            dtype=dtype, device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory.
            # NOTE(review): state_adaptor was built for 2 * state_token_dim
            # input features, so this concat assumes action_dim ==
            # state_token_dim — confirm against the config.
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            model_output = self.model(state_action_traj, ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond, img_cond, lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(
                model_output, t, noisy_action).prev_sample
            # Keep the working dtype stable across scheduler steps
            noisy_action = noisy_action.to(dtype)

        # Finally apply the action mask to mask invalid action dimensions
        return noisy_action * action_mask

    # ========= Train ============
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens,
                     state_tokens, action_gt, action_mask, ctrl_freqs
                     ) -> torch.Tensor:
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor (MSE between prediction and the
            target selected by self.prediction_type)

        Raises:
            ValueError: if self.prediction_type is not 'epsilon', 'sample'
                or 'v_prediction'.
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Sample noise that we'll add to the actions
        noise = torch.randn(
            action_gt.shape, dtype=action_gt.dtype, device=device
        )
        # Sample random diffusion timesteps (torch.randint yields int64)
        timesteps = torch.randint(
            0, self.num_train_timesteps,
            (batch_size,), device=device
        )
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(
            action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs,
                          timesteps, lang_cond, img_cond,
                          lang_mask=lang_attn_mask)

        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        elif pred_type == 'v_prediction':
            # diffusers' DDPM and DPM-Solver schedulers also accept
            # 'v_prediction'; its training target is the velocity under the
            # same noise schedule used by add_noise above.
            target = self.noise_scheduler.get_velocity(
                action_gt, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target)
        return loss

    # ========= Inference ============
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens,
                       action_mask, ctrl_freqs):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence

        state_tokens and action_mask are concatenated on the feature axis
        before the state adaptor, which was built for 2 * state_token_dim
        input features.
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond, lang_attn_mask, img_cond,
            state_traj, action_mask, ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Training entry point; delegates to ``compute_loss``."""
        return self.compute_loss(*args, **kwargs)