add ability to condition on succeeding frames of a video
lucidrains committed Feb 7, 2023
1 parent 3c24c60 commit 89cfde5
Showing 4 changed files with 63 additions and 4 deletions.
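As a rough usage sketch of the new capability (hedged: the `imagen` object below is assumed to be a video Imagen / ElucidatedImagen wrapping a Unet3D, constructed as in the project README; shapes and values are illustrative, not from this commit), sampling can now be conditioned on frames both before and after the generated clip:

    import torch

    # shapes are (batch, channels, frames, height, width)
    preceding  = torch.randn(1, 3, 2, 16, 16)  # frames the generated clip should follow
    succeeding = torch.randn(1, 3, 2, 16, 16)  # frames the generated clip should lead into

    videos = imagen.sample(
        texts = ['a whale breaching'],
        cond_video_frames = preceding,
        post_cond_video_frames = succeeding,   # new keyword added by this commit
        video_frames = 8
    )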
2 changes: 2 additions & 0 deletions imagen_pytorch/elucidated_imagen.py
@@ -543,6 +543,7 @@ def sample(
text_embeds = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
@@ -673,6 +674,7 @@ def sample(
text_mask = text_masks,
cond_images = cond_images,
cond_video_frames = cond_video_frames,
post_cond_video_frames = post_cond_video_frames,
inpaint_images = inpaint_images,
inpaint_masks = inpaint_masks,
inpaint_resample_times = inpaint_resample_times,
41 changes: 39 additions & 2 deletions imagen_pytorch/imagen_pytorch.py
@@ -2031,6 +2031,7 @@ def p_mean_variance(
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
lowres_cond_img = None,
self_cond = None,
lowres_noise_times = None,
@@ -2042,7 +2043,19 @@
):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

pred = default(model_output, lambda: unet.forward_with_cond_scale(
x,
noise_scheduler.get_condition(t),
text_embeds = text_embeds,
text_mask = text_mask,
cond_images = cond_images,
cond_video_frames = cond_video_frames,
post_cond_video_frames = post_cond_video_frames,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
self_cond = self_cond,
lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times))
)

if pred_objective == 'noise':
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2084,6 +2097,7 @@ def p_sample(
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
cond_scale = 1.,
self_cond = None,
lowres_cond_img = None,
@@ -2092,7 +2106,26 @@
dynamic_threshold = True
):
b, *_, device = *x.shape, x.device

(model_mean, _, model_log_variance), x_start = self.p_mean_variance(
unet,
x = x,
t = t,
t_next = t_next,
noise_scheduler = noise_scheduler,
text_embeds = text_embeds,
text_mask = text_mask,
cond_images = cond_images,
cond_video_frames = cond_video_frames,
post_cond_video_frames = post_cond_video_frames,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
self_cond = self_cond,
lowres_noise_times = lowres_noise_times,
pred_objective = pred_objective,
dynamic_threshold = dynamic_threshold
)

noise = torch.randn_like(x)
# no noise when t == 0
is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
@@ -2113,6 +2146,7 @@ def p_sample_loop(
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
@@ -2183,6 +2217,7 @@ def p_sample_loop(
text_mask = text_mask,
cond_images = cond_images,
cond_video_frames = cond_video_frames,
post_cond_video_frames = post_cond_video_frames,
cond_scale = cond_scale,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img,
@@ -2222,6 +2257,7 @@ def sample(
video_frames = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
inpaint_images = None,
inpaint_masks = None,
inpaint_resample_times = 5,
@@ -2347,6 +2383,7 @@ def sample(
text_mask = text_masks,
cond_images = cond_images,
cond_video_frames = cond_video_frames,
post_cond_video_frames = post_cond_video_frames,
inpaint_images = inpaint_images,
inpaint_masks = inpaint_masks,
inpaint_resample_times = inpaint_resample_times,
22 changes: 21 additions & 1 deletion imagen_pytorch/imagen_video.py
@@ -1599,6 +1599,7 @@ def forward(
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
self_cond = None,
cond_drop_prob = 0.,
ignore_time = False
@@ -1627,8 +1628,11 @@
lowres_cond_img = torch.cat((cond_video_frames, lowres_cond_img), dim = 2)
cond_video_frames = torch.cat((cond_video_frames, cond_video_frames), dim = 1)

if exists(post_cond_video_frames):
lowres_cond_img = torch.cat((lowres_cond_img, post_cond_video_frames), dim = 2)
post_cond_video_frames = torch.cat((post_cond_video_frames, post_cond_video_frames), dim = 1)

# conditioning on preceding video frames as a prompt

num_preceding_frames = 0
if exists(cond_video_frames):
@@ -1641,6 +1645,19 @@ def forward(

num_preceding_frames = cond_video_frames_len

# conditioning on succeeding video frames

num_succeeding_frames = 0
if exists(post_cond_video_frames):
cond_video_frames_len = post_cond_video_frames.shape[2]

assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)

post_cond_video_frames = resize_video_to(post_cond_video_frames, x.shape[-1])
x = torch.cat((x, post_cond_video_frames), dim = 2)

num_succeeding_frames = cond_video_frames_len

# condition on input image

assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'
@@ -1853,4 +1870,7 @@ def forward(
if num_preceding_frames > 0:
out = out[:, :, num_preceding_frames:]

if num_succeeding_frames > 0:
out = out[:, :, :-num_succeeding_frames]

return out
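To make the mechanics above concrete, here is a minimal standalone sketch (not library code; `toy_unet` is a hypothetical stand-in for the denoising network) of the concatenate-then-trim pattern this diff implements: conditioning frames are joined to the input along the frame axis (dim = 2) before the network runs, then sliced back off so only the generated frames are returned:

    import torch

    toy_unet = lambda t: t  # identity stand-in for the video unet

    x    = torch.randn(1, 3, 8, 16, 16)  # (batch, channels, frames, height, width)
    pre  = torch.randn(1, 3, 2, 16, 16)  # preceding conditioning frames
    post = torch.randn(1, 3, 2, 16, 16)  # succeeding conditioning frames

    x = torch.cat((pre, x, post), dim = 2)  # network sees 12 frames in temporal order
    out = toy_unet(x)

    out = out[:, :, pre.shape[2]:]     # drop the preceding frames
    out = out[:, :, :-post.shape[2]]   # drop the succeeding frames
    assert out.shape[2] == 8           # only the generated frames remain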
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.21.2'
+__version__ = '1.21.3'
