From 52e73b2b6751fa214d9d7ba0061cefa481183f05 Mon Sep 17 00:00:00 2001 From: George Grigorev Date: Thu, 16 Feb 2023 21:40:30 +0800 Subject: [PATCH] add dpm_solver support --- ldm/models/diffusion/dpm_solver/dpm_solver.py | 25 +++++++++++++++++-- ldm/models/diffusion/dpm_solver/sampler.py | 8 +++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3ce..a336aba922 100644 --- a/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -307,7 +307,28 @@ def model_fn(x, t_continuous): else: x_in = torch.cat([x] * 2) t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) + + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([ + unconditional_condition[k][i], + condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([ + unconditional_condition[k], + condition[k]]) + elif isinstance(condition, list): + c_in = list() + assert isinstance(unconditional_condition, list) + for i in range(len(condition)): + c_in.append(torch.cat([unconditional_condition[i], condition[i]])) + else: + c_in = torch.cat([unconditional_condition, condition]) + +# c_in = torch.cat([unconditional_condition, condition]) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) return noise_uncond + guidance_scale * (noise - noise_uncond) @@ -317,7 +338,7 @@ def model_fn(x, t_continuous): class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=True, max_val=1.): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf3..b3c4377204 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -50,9 +50,15 @@ def sample(self, ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")