Skip to content

Commit 90de55b

Browse files
Can-Zhaoyiheng-wang-nvpooya-mohammadiKumoLiuericspod
authored
Add rectified flow noise scheduler for accelerated diffusion model (#8374)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]> Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: pooya-mohammadi <[email protected]> Signed-off-by: YunLiu <[email protected]> Signed-off-by: advcu987 <[email protected]> Signed-off-by: advcu <[email protected]> Signed-off-by: Eric Kerfoot <[email protected]> Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: R. Garcia-Dias <[email protected]> Signed-off-by: Rafael Garcia-Dias <[email protected]> Signed-off-by: Eric Kerfoot <[email protected]> Signed-off-by: James Butler <[email protected]> Signed-off-by: Virginia Fernandez <[email protected]> Signed-off-by: Nicolas Kaenzig <[email protected]> Signed-off-by: Bartosz Grabowski <[email protected]> Signed-off-by: thibaultdvx <[email protected]> Signed-off-by: Thibault de Varax <[email protected]> Signed-off-by: monai-bot <[email protected]> Co-authored-by: Yiheng Wang <[email protected]> Co-authored-by: Pooya Mohammadi Kazaj <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: advcu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rafael Garcia-Dias <[email protected]> Co-authored-by: Rafael Garcia-Dias <[email protected]> Co-authored-by: Ben Murray <[email protected]> Co-authored-by: James Butler <[email protected]> Co-authored-by: Virginia Fernandez <[email protected]> Co-authored-by: Virginia Fernandez <[email protected]> Co-authored-by: Nicolas Känzig <[email protected]> Co-authored-by: Bartosz Grabowski <[email protected]> Co-authored-by: Thibault de Varax <[email protected]> Co-authored-by: monai-bot <[email protected]>
1 parent 7c26e5a commit 90de55b

File tree

9 files changed

+892
-253
lines changed

9 files changed

+892
-253
lines changed

docs/source/networks.rst

+35
Original file line numberDiff line numberDiff line change
@@ -750,3 +750,38 @@ Utilities
750750

751751
.. automodule:: monai.apps.reconstruction.networks.nets.utils
752752
:members:
753+
754+
Noise Schedulers
755+
----------------
756+
.. automodule:: monai.networks.schedulers
757+
.. currentmodule:: monai.networks.schedulers
758+
759+
`Scheduler`
760+
~~~~~~~~~~~
761+
.. autoclass:: Scheduler
762+
:members:
763+
764+
`NoiseSchedules`
765+
~~~~~~~~~~~~~~~~
766+
.. autoclass:: NoiseSchedules
767+
:members:
768+
769+
`DDPMScheduler`
770+
~~~~~~~~~~~~~~~
771+
.. autoclass:: DDPMScheduler
772+
:members:
773+
774+
`DDIMScheduler`
775+
~~~~~~~~~~~~~~~
776+
.. autoclass:: DDIMScheduler
777+
:members:
778+
779+
`PNDMScheduler`
780+
~~~~~~~~~~~~~~~
781+
.. autoclass:: PNDMScheduler
782+
:members:
783+
784+
`RFlowScheduler`
785+
~~~~~~~~~~~~~~~~
786+
.. autoclass:: RFlowScheduler
787+
:members:

monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

+4
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
232232
if self.print_info:
233233
logger.info(f"Number of splits: {self.num_splits}")
234234

235+
if self.dim_split <= 1 and self.num_splits <= 1:
236+
x = self.conv(x)
237+
return x
238+
235239
# compute size of splits
236240
l = x.size(self.dim_split + 2)
237241
split_size = l // self.num_splits

monai/inferers/inferer.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
SPADEAutoencoderKL,
4040
SPADEDiffusionModelUNet,
4141
)
42-
from monai.networks.schedulers import Scheduler
42+
from monai.networks.schedulers import RFlowScheduler, Scheduler
4343
from monai.transforms import CenterSpatialCrop, SpatialPad
4444
from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
4545
from monai.visualize import CAM, GradCAM, GradCAMpp
@@ -859,12 +859,18 @@ def sample(
859859
if not scheduler:
860860
scheduler = self.scheduler
861861
image = input_noise
862+
863+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
862864
if verbose and has_tqdm:
863-
progress_bar = tqdm(scheduler.timesteps)
865+
progress_bar = tqdm(
866+
zip(scheduler.timesteps, all_next_timesteps),
867+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
868+
)
864869
else:
865-
progress_bar = iter(scheduler.timesteps)
870+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
866871
intermediates = []
867-
for t in progress_bar:
872+
873+
for t, next_t in progress_bar:
868874
# 1. predict noise model_output
869875
diffusion_model = (
870876
partial(diffusion_model, seg=seg)
@@ -882,9 +888,13 @@ def sample(
882888
)
883889

884890
# 2. compute previous image: x_t -> x_t-1
885-
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
891+
if not isinstance(scheduler, RFlowScheduler):
892+
image, _ = scheduler.step(model_output, t, image) # type: ignore
893+
else:
894+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
886895
if save_intermediates and t % intermediate_steps == 0:
887896
intermediates.append(image)
897+
888898
if save_intermediates:
889899
return image, intermediates
890900
else:
@@ -1392,12 +1402,18 @@ def sample( # type: ignore[override]
13921402
if not scheduler:
13931403
scheduler = self.scheduler
13941404
image = input_noise
1405+
1406+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
13951407
if verbose and has_tqdm:
1396-
progress_bar = tqdm(scheduler.timesteps)
1408+
progress_bar = tqdm(
1409+
zip(scheduler.timesteps, all_next_timesteps),
1410+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
1411+
)
13971412
else:
1398-
progress_bar = iter(scheduler.timesteps)
1413+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
13991414
intermediates = []
1400-
for t in progress_bar:
1415+
1416+
for t, next_t in progress_bar:
14011417
diffuse = diffusion_model
14021418
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14031419
diffuse = partial(diffusion_model, seg=seg)
@@ -1436,7 +1452,11 @@ def sample( # type: ignore[override]
14361452
)
14371453

14381454
# 3. compute previous image: x_t -> x_t-1
1439-
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
1455+
if not isinstance(scheduler, RFlowScheduler):
1456+
image, _ = scheduler.step(model_output, t, image) # type: ignore
1457+
else:
1458+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
1459+
14401460
if save_intermediates and t % intermediate_steps == 0:
14411461
intermediates.append(image)
14421462
if save_intermediates:

monai/networks/schedulers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from .ddim import DDIMScheduler
1515
from .ddpm import DDPMScheduler
1616
from .pndm import PNDMScheduler
17+
from .rectified_flow import RFlowScheduler
1718
from .scheduler import NoiseSchedules, Scheduler

0 commit comments

Comments
 (0)