39
39
SPADEAutoencoderKL ,
40
40
SPADEDiffusionModelUNet ,
41
41
)
42
- from monai .networks .schedulers import Scheduler
42
+ from monai .networks .schedulers import RFlowScheduler , Scheduler
43
43
from monai .transforms import CenterSpatialCrop , SpatialPad
44
44
from monai .utils import BlendMode , Ordering , PatchKeys , PytorchPadMode , ensure_tuple , optional_import
45
45
from monai .visualize import CAM , GradCAM , GradCAMpp
@@ -859,12 +859,18 @@ def sample(
859
859
if not scheduler :
860
860
scheduler = self .scheduler
861
861
image = input_noise
862
+
863
+ all_next_timesteps = torch .cat ((scheduler .timesteps [1 :], torch .tensor ([0 ], dtype = scheduler .timesteps .dtype )))
862
864
if verbose and has_tqdm :
863
- progress_bar = tqdm (scheduler .timesteps )
865
+ progress_bar = tqdm (
866
+ zip (scheduler .timesteps , all_next_timesteps ),
867
+ total = min (len (scheduler .timesteps ), len (all_next_timesteps )),
868
+ )
864
869
else :
865
- progress_bar = iter (scheduler .timesteps )
870
+ progress_bar = iter (zip ( scheduler .timesteps , all_next_timesteps ) )
866
871
intermediates = []
867
- for t in progress_bar :
872
+
873
+ for t , next_t in progress_bar :
868
874
# 1. predict noise model_output
869
875
diffusion_model = (
870
876
partial (diffusion_model , seg = seg )
@@ -882,9 +888,13 @@ def sample(
882
888
)
883
889
884
890
# 2. compute previous image: x_t -> x_t-1
885
- image , _ = scheduler .step (model_output , t , image ) # type: ignore[operator]
891
+ if not isinstance (scheduler , RFlowScheduler ):
892
+ image , _ = scheduler .step (model_output , t , image ) # type: ignore
893
+ else :
894
+ image , _ = scheduler .step (model_output , t , image , next_t ) # type: ignore
886
895
if save_intermediates and t % intermediate_steps == 0 :
887
896
intermediates .append (image )
897
+
888
898
if save_intermediates :
889
899
return image , intermediates
890
900
else :
@@ -1392,12 +1402,18 @@ def sample( # type: ignore[override]
1392
1402
if not scheduler :
1393
1403
scheduler = self .scheduler
1394
1404
image = input_noise
1405
+
1406
+ all_next_timesteps = torch .cat ((scheduler .timesteps [1 :], torch .tensor ([0 ], dtype = scheduler .timesteps .dtype )))
1395
1407
if verbose and has_tqdm :
1396
- progress_bar = tqdm (scheduler .timesteps )
1408
+ progress_bar = tqdm (
1409
+ zip (scheduler .timesteps , all_next_timesteps ),
1410
+ total = min (len (scheduler .timesteps ), len (all_next_timesteps )),
1411
+ )
1397
1412
else :
1398
- progress_bar = iter (scheduler .timesteps )
1413
+ progress_bar = iter (zip ( scheduler .timesteps , all_next_timesteps ) )
1399
1414
intermediates = []
1400
- for t in progress_bar :
1415
+
1416
+ for t , next_t in progress_bar :
1401
1417
diffuse = diffusion_model
1402
1418
if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
1403
1419
diffuse = partial (diffusion_model , seg = seg )
@@ -1436,7 +1452,11 @@ def sample( # type: ignore[override]
1436
1452
)
1437
1453
1438
1454
# 3. compute previous image: x_t -> x_t-1
1439
- image , _ = scheduler .step (model_output , t , image ) # type: ignore[operator]
1455
+ if not isinstance (scheduler , RFlowScheduler ):
1456
+ image , _ = scheduler .step (model_output , t , image ) # type: ignore
1457
+ else :
1458
+ image , _ = scheduler .step (model_output , t , image , next_t ) # type: ignore
1459
+
1440
1460
if save_intermediates and t % intermediate_steps == 0 :
1441
1461
intermediates .append (image )
1442
1462
if save_intermediates :
0 commit comments