-
Notifications
You must be signed in to change notification settings - Fork 144
/
generate.py
199 lines (176 loc) · 6.62 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
Generate a video with the Oasis DiT and ViT-VAE from a prompt image/video and an action stream.

References:
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing
"""
import torch
from dit import DiT_models
from vae import VAE_models
from torchvision.io import read_video, write_video
from utils import load_prompt, load_actions, sigmoid_beta_schedule
from tqdm import tqdm
from einops import rearrange
from torch import autocast
from safetensors.torch import load_model
import argparse
from pprint import pprint
import os
# Fail at import time when no CUDA device is available; main() below places
# every model and tensor on `device`.
# NOTE(review): `assert` statements are stripped under `python -O`, so this
# guard disappears in optimized mode.
assert torch.cuda.is_available()
# Device string used by main() for all `.to(device)` / `device=` placements.
device = "cuda:0"
def _load_checkpoint(module, ckpt_path, *, strict=True):
    """Load the weights stored at ``ckpt_path`` into ``module`` in place.

    Args:
        module: Model to populate. ``.pt`` checkpoints are applied with
            ``module.load_state_dict``; ``.safetensors`` checkpoints are
            applied with safetensors' ``load_model``.
        ckpt_path: Path to the checkpoint; the loader is chosen from its
            file extension.
        strict: Forwarded to ``load_state_dict`` for ``.pt`` checkpoints
            only. ``.safetensors`` checkpoints use ``load_model`` with its
            default settings.

    Raises:
        ValueError: If ``ckpt_path`` ends in neither ``.pt`` nor
            ``.safetensors``. Without this check the model would silently
            keep its freshly initialised weights.
    """
    if ckpt_path.endswith(".pt"):
        state_dict = torch.load(ckpt_path, weights_only=True)
        module.load_state_dict(state_dict, strict=strict)
    elif ckpt_path.endswith(".safetensors"):
        load_model(module, ckpt_path)
    else:
        raise ValueError(
            f"unsupported checkpoint format for {ckpt_path!r}: "
            "expected a '.pt' or '.safetensors' file"
        )


def main(args):
    """Sample a video frame by frame and write it to ``args.output_path``.

    The prompt frames are encoded with the VAE, then each further frame is
    appended as clamped Gaussian noise and denoised over ``args.ddim_steps``
    steps by the DiT, conditioned on the earlier frames (within the model's
    ``max_frames`` window) and on the action stream. The resulting latents
    are decoded by the VAE and saved as a video.

    Args:
        args: Namespace providing ``oasis_ckpt``, ``vae_ckpt``,
            ``n_prompt_frames``, ``num_frames``, ``ddim_steps``,
            ``prompt_path``, ``video_offset``, ``actions_path``,
            ``output_path`` and ``fps``.

    Raises:
        ValueError: If either checkpoint path has an unsupported extension.
    """
    # Fixed seeds so repeated runs draw the same noise.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # load DiT checkpoint
    model = DiT_models["DiT-S/2"]()
    print(f"loading Oasis-500M from oasis-ckpt={os.path.abspath(args.oasis_ckpt)}...")
    # `.pt` DiT checkpoints are loaded non-strictly, as before.
    _load_checkpoint(model, args.oasis_ckpt, strict=False)
    model = model.to(device).eval()

    # load VAE checkpoint
    vae = VAE_models["vit-l-20-shallow-encoder"]()
    print(f"loading ViT-VAE-L/20 from vae-ckpt={os.path.abspath(args.vae_ckpt)}...")
    _load_checkpoint(vae, args.vae_ckpt)
    vae = vae.to(device).eval()

    # sampling params
    n_prompt_frames = args.n_prompt_frames
    total_frames = args.num_frames
    max_noise_level = 1000
    ddim_noise_steps = args.ddim_steps
    # ddim_noise_steps + 1 evenly spaced levels from -1 to max_noise_level - 1;
    # the -1 endpoint is handled inside the loop (see `t_next` below).
    noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
    noise_abs_max = 20
    stabilization_level = 15

    # get prompt image/video
    x = load_prompt(
        args.prompt_path,
        video_offset=args.video_offset,
        n_prompt_frames=n_prompt_frames,
    )
    # get input action stream (dim 1 is sliced per frame below)
    actions = load_actions(args.actions_path, action_offset=args.video_offset)[:, :total_frames]

    # sampling inputs
    x = x.to(device)
    actions = actions.to(device)

    # vae encoding
    B = x.shape[0]
    H, W = x.shape[-2:]
    scaling_factor = 0.07843137255
    x = rearrange(x, "b t c h w -> (b t) c h w")
    with torch.no_grad(), autocast("cuda", dtype=torch.half):
        # NOTE(review): `x * 2 - 1` presumably maps [0, 1] pixels to [-1, 1] -- confirm in load_prompt.
        x = vae.encode(x * 2 - 1).mean * scaling_factor
    x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H // vae.patch_size, w=W // vae.patch_size)
    x = x[:, :n_prompt_frames]

    # get alphas
    betas = sigmoid_beta_schedule(max_noise_level).float().to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Trailing singleton dims let `alphas_cumprod[t]` broadcast against (b, t, c, h, w) latents.
    alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")

    # sampling loop
    for i in tqdm(range(n_prompt_frames, total_frames)):
        # Append the new frame as clamped Gaussian noise.
        chunk = torch.randn((B, 1, *x.shape[-3:]), device=device)
        chunk = torch.clamp(chunk, -noise_abs_max, +noise_abs_max)
        x = torch.cat([x, chunk], dim=1)
        start_frame = max(0, i + 1 - model.max_frames)
        # Noise level assigned to the i context frames; it does not depend on
        # the denoising step, so it is built once per generated frame.
        t_ctx = torch.full((B, i), stabilization_level - 1, dtype=torch.long, device=device)

        for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
            # set up noise values
            t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
            t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
            # The final target level is -1, which is not a valid index into
            # alphas_cumprod; reuse the current level (alpha is forced to 1 below).
            t_next = torch.where(t_next < 0, t, t_next)
            t = torch.cat([t_ctx, t], dim=1)
            t_next = torch.cat([t_ctx, t_next], dim=1)

            # sliding window
            x_curr = x.clone()
            x_curr = x_curr[:, start_frame:]
            t = t[:, start_frame:]
            t_next = t_next[:, start_frame:]

            # get model predictions
            with torch.no_grad(), autocast("cuda", dtype=torch.half):
                v = model(x_curr, t, actions[:, start_frame : i + 1])

            # Convert the model output into a clean-sample estimate and a noise estimate.
            x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
            x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) / (1 / alphas_cumprod[t] - 1).sqrt()

            # get frame prediction
            alpha_next = alphas_cumprod[t_next]
            # Context frames get alpha = 1 (no noise re-added); only the last frame is written back anyway.
            alpha_next[:, :-1] = torch.ones_like(alpha_next[:, :-1])
            if noise_idx == 1:
                # Last step: the new frame becomes the clean estimate.
                alpha_next[:, -1:] = torch.ones_like(alpha_next[:, -1:])
            x_pred = alpha_next.sqrt() * x_start + x_noise * (1 - alpha_next).sqrt()
            x[:, -1:] = x_pred[:, -1:]

    # vae decoding
    # When the prompt alone is longer than the requested length, keep only the
    # first `total_frames` frames and reshape with the real frame count, so the
    # final rearrange cannot fail or mis-group the batch dimension.
    x = x[:, :total_frames]
    n_out_frames = x.shape[1]
    x = rearrange(x, "b t c h w -> (b t) (h w) c")
    with torch.no_grad():
        x = (vae.decode(x / scaling_factor) + 1) / 2
    x = rearrange(x, "(b t) c h w -> b t h w c", t=n_out_frames)

    # save video
    x = torch.clamp(x, 0, 1)
    x = (x * 255).byte()
    # Only the first batch element is written.
    write_video(args.output_path, x[0].cpu(), fps=args.fps)
    print(f"generation saved to {args.output_path}.")
if __name__ == "__main__":
    # Command-line options as (flag, type, help, default) rows, registered in
    # this order so the generated --help output keeps its ordering.
    option_table = (
        ("--oasis-ckpt", str, "Path to Oasis DiT checkpoint.", "oasis500m.safetensors"),
        ("--vae-ckpt", str, "Path to Oasis ViT-VAE checkpoint.", "vit-l-20.safetensors"),
        ("--num-frames", int, "How many frames should the output be?", 32),
        (
            "--prompt-path",
            str,
            "Path to image or video to condition generation on.",
            "sample_data/sample_image_0.png",
        ),
        (
            "--actions-path",
            str,
            "File to load actions from (.actions.pt or .one_hot_actions.pt)",
            "sample_data/sample_actions_0.one_hot_actions.pt",
        ),
        (
            "--video-offset",
            int,
            "If loading prompt from video, index of frame to start reading from.",
            None,
        ),
        (
            "--n-prompt-frames",
            int,
            "If the prompt is a video, how many frames to condition on.",
            1,
        ),
        ("--output-path", str, "Path where generated video should be saved.", "video.mp4"),
        ("--fps", int, "What framerate should be used to save the output?", 20),
        ("--ddim-steps", int, "How many DDIM steps?", 10),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, help_text, default_value in option_table:
        parser.add_argument(flag, type=value_type, help=help_text, default=default_value)

    args = parser.parse_args()
    # Echo the resolved configuration before starting generation.
    print("inference args:")
    pprint(vars(args))
    main(args)