17 changes: 17 additions & 0 deletions rslearn/train/callbacks/freeze_unfreeze.py
@@ -3,6 +3,7 @@
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import BaseFinetuning
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

@@ -85,6 +86,22 @@ def finetune_function(
"appending to ReduceLROnPlateau scheduler min_lrs for unfreeze"
)
scheduler.min_lrs.append(scheduler.min_lrs[0])
for callback in pl_module.trainer.callbacks:
if isinstance(callback, StochasticWeightAveraging):
if isinstance(callback._swa_lrs, list):
assert (
len(callback._swa_lrs) == 1
), "only one swa lr is supported"
swa_lr = callback._swa_lrs[0]
elif isinstance(callback._swa_lrs, float):
swa_lr = callback._swa_lrs
else:
raise ValueError(
f"unknown swa lr type: {type(callback._swa_lrs)}"
)
for param_group in optimizer.param_groups:
param_group["swa_lr"] = swa_lr
logger.info(f"setting all swa lrs to {swa_lr}")
elif current_epoch > self.unfreeze_at_epoch:
# always do this because overhead is minimal, and it allows restoring
# from a checkpoint (resuming a run) without messing up unfreezing
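
For context, here is a hedged sketch of how the two callbacks might be combined in a Lightning Trainer so that this code path runs: when the freeze/unfreeze callback adds a new param group at the unfreeze epoch, the StochasticWeightAveraging callback's single swa_lr is copied onto every optimizer param group. The class name FreezeUnfreeze and its constructor argument unfreeze_at_epoch are assumptions inferred from the file path and the attribute referenced in the diff; check rslearn/train/callbacks/freeze_unfreeze.py for the real signature.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

from rslearn.train.callbacks.freeze_unfreeze import FreezeUnfreeze

callbacks = [
    # Unfreeze the frozen modules at epoch 5 (constructor argument name assumed;
    # only the unfreeze_at_epoch attribute appears in the diff above).
    FreezeUnfreeze(unfreeze_at_epoch=5),
    # A single float (or one-element list) for swa_lrs matches the branch in the
    # diff that copies one swa_lr into every optimizer param group.
    StochasticWeightAveraging(swa_lrs=1e-4),
]

trainer = Trainer(max_epochs=50, callbacks=callbacks)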