diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 85bfb65c0ea6e..f771304a525c7 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -36,12 +36,20 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import Checkpoint
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
 from lightning.pytorch.utilities.types import STEP_OUTPUT
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)
+from lightning_fabric.utilities.types import _PATH
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -244,6 +252,7 @@ def __init__(
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
 
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict["best_model_metrics"]
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -361,8 +372,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
-                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`,"
+                " `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
             )
 
         self.best_model_path = state_dict["best_model_path"]
@@ -746,6 +757,8 @@ def _update_best_and_save(
         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
+        if self.best_model_path == filepath:
+            self.best_model_metrics = monitor_candidates
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
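For context, a minimal usage sketch (the model and datamodule here are placeholders; any `LightningModule` that logs the monitored key works): after `trainer.fit(...)`, the new `best_model_metrics` attribute exposes every metric that was logged at the moment the best checkpoint was saved, whereas `best_model_score` keeps only the monitored value.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Hypothetical setup: `MyModel` and `datamodule` stand in for any
# LightningModule / DataModule pair that logs "val_loss" during validation.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb])
trainer.fit(MyModel(), datamodule=datamodule)

# Only the monitored value (a Tensor):
print(checkpoint_cb.best_model_score)
# All metrics logged when the best checkpoint was written, e.g.
# {"val_loss": tensor(...), "epoch": tensor(...), "step": tensor(...)}:
print(checkpoint_cb.best_model_metrics)
```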
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 1907a5fb35799..c0e4c3f2c4f2b 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -32,7 +32,6 @@
 from torch import optim
 from torch.utils.data.dataloader import DataLoader
 
-import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -703,6 +702,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
 
@@ -809,6 +809,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
@@ -1074,7 +1075,7 @@ def assert_checkpoint_log_dir(idx):
 
     # load from checkpoint
     trainer_config["logger"] = TensorBoardLogger(tmp_path)
-    trainer = pl.Trainer(**trainer_config)
+    trainer = Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
     model = ExtendedBoringModel()
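Since `best_model_metrics` is also written to (and restored from) the callback's `state_dict`, it survives resuming as long as `dirpath` is unchanged. A minimal round-trip sketch with placeholder values:

```python
import torch

from lightning.pytorch.callbacks import ModelCheckpoint

# Placeholder values standing in for metrics captured during a real run.
src = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss")
src.best_model_metrics = {"val_loss": torch.tensor(0.25), "epoch": torch.tensor(3)}

# A fresh callback pointing at the same dirpath restores the metrics;
# with a different dirpath, load_state_dict() warns and skips them.
dst = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss")
dst.load_state_dict(src.state_dict())
print(dst.best_model_metrics)  # {"val_loss": tensor(0.2500), "epoch": tensor(3)}
```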