diff --git a/rslearn/train/callbacks/freeze_unfreeze.py b/rslearn/train/callbacks/freeze_unfreeze.py
index d3f46a90..2f356262 100644
--- a/rslearn/train/callbacks/freeze_unfreeze.py
+++ b/rslearn/train/callbacks/freeze_unfreeze.py
@@ -3,6 +3,7 @@
 import torch
 from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import BaseFinetuning
+from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
@@ -85,6 +86,22 @@ def finetune_function(
                     "appending to ReduceLROnPlateau scheduler min_lrs for unfreeze"
                 )
                 scheduler.min_lrs.append(scheduler.min_lrs[0])
+            for callback in pl_module.trainer.callbacks:
+                if isinstance(callback, StochasticWeightAveraging):
+                    if isinstance(callback._swa_lrs, list):
+                        assert (
+                            len(callback._swa_lrs) == 1
+                        ), "only one swa lr is supported"
+                        swa_lr = callback._swa_lrs[0]
+                    elif isinstance(callback._swa_lrs, float):
+                        swa_lr = callback._swa_lrs
+                    else:
+                        raise ValueError(
+                            f"unknown swa lr type: {type(callback._swa_lrs)}"
+                        )
+                    for param_group in optimizer.param_groups:
+                        param_group["swa_lr"] = swa_lr
+                    logger.info(f"setting all swa lrs to {swa_lr}")
         elif current_epoch > self.unfreeze_at_epoch:
             # always do this because overhead is minimal, and it allows restoring
             # from a checkpoint (resuming a run) without messing up unfreezing
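
Note on the param_group["swa_lr"] assignment added above: Lightning's StochasticWeightAveraging builds a torch.optim.swa_utils.SWALR scheduler from its configured _swa_lrs, and SWALR stamps a "swa_lr" entry onto each optimizer param group that exists at construction time, then reads that entry back on every step. If unfreezing happens after SWA has kicked in, the param group appended by unfreeze_and_add_param_group never receives the key, so the next SWA scheduler step would fail with a KeyError. The sketch below reproduces that failure mode and the manual repair outside the callback; the modules, optimizer, and learning-rate values are illustrative only and not taken from this patch.

    import torch
    from torch.optim.swa_utils import SWALR

    head = torch.nn.Linear(4, 2)
    backbone = torch.nn.Linear(4, 4)

    optimizer = torch.optim.SGD(head.parameters(), lr=1e-3)
    # SWALR writes "swa_lr" into every param group that exists right now.
    swa_scheduler = SWALR(optimizer, swa_lr=1e-4)

    # Later, an unfreeze step appends a fresh param group for the backbone;
    # this mirrors what unfreeze_and_add_param_group does.
    optimizer.add_param_group({"params": backbone.parameters(), "lr": 1e-4})

    # Without the fix, the new group has no "swa_lr" and the next
    # swa_scheduler.step() would raise a KeyError. The patch repairs this by
    # copying the callback's configured SWA lr onto every param group:
    for param_group in optimizer.param_groups:
        param_group["swa_lr"] = 1e-4

    optimizer.step()
    swa_scheduler.step()  # now succeeds for both param groups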