fix checkpointing and validation with InfoNCE #26

Open

hummuscience wants to merge 1 commit into paninski-lab:main from hummuscience:fix/checkpointing-and-validation

Conversation

@hummuscience

Same as what was described here: #21

Three issues fixed:

  1. ContrastBatchSampler bug: pos_indices from extract_anchor_indices are
    local (subset) indices, but were compared against global dataset_indices
    in the valid_positives filter. This caused all positives to be rejected
    for non-zero-based subsets (val/test), producing empty batches. Removed
    the incorrect check since local indices are valid by construction.

  2. val_dataloader now uses ContrastBatchSampler and contrastive_collate_fn
    when use_sampler=True, matching the training dataloader. This enables
    proper InfoNCE computation during validation, so val_loss is logged and
    the best-model checkpoint callback can trigger.

  3. Add save_last=True to ModelCheckpoint so a last.ckpt is always
    maintained, preventing total loss of training progress on interruption.
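The index-space mismatch in item 1 can be sketched in a few lines (the function name, index values, and split layout here are illustrative, not the actual BEAST code):

```python
# Sketch of the bug in item 1: pos_indices are LOCAL indices (0..len-1 within
# a Subset), while dataset_indices holds GLOBAL positions into the parent
# dataset. Filtering local positives against global indices rejects everything
# for any subset that does not start at global index 0 (e.g. val/test).

def filter_positives_buggy(pos_indices, dataset_indices):
    # Compares local subset indices against global dataset indices,
    # so the intersection is empty for non-zero-based subsets.
    return [i for i in pos_indices if i in dataset_indices]

# Suppose the val split covers global frames 900..999; local index 3 of the
# subset actually corresponds to global frame 903.
dataset_indices = list(range(900, 1000))  # global indices of the val subset
pos_indices = [3, 4, 5]                   # local indices within the subset

print(filter_positives_buggy(pos_indices, dataset_indices))  # [] -> empty batches
```

Since local indices are valid by construction (the subset only contains items it was built from), dropping the check entirely avoids the empty-batch failure.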

@hummuscience
Author

Hmmm, I am not sure this fixes things. The checkpoint saving depends on the validation working (unless one activates saving checkpoints every N epochs).

But somehow I am not getting validation losses while training. Will try to dig deeper.

Four issues fixed:

1. ContrastBatchSampler bug: pos_indices from extract_anchor_indices are
   local (subset) indices, but were compared against global dataset_indices
   in the valid_positives filter. This caused all positives to be rejected
   for non-zero-based subsets (val/test), producing empty batches. Removed
   the incorrect check since local indices are valid by construction.

2. val_dataloader now uses ContrastBatchSampler and contrastive_collate_fn
   when use_sampler=True, matching the training dataloader. Both dataloaders
   use batch_sampler= (not sampler=) since ContrastBatchSampler yields
   batches of indices, not individual indices.

3. Add save_last=True to ModelCheckpoint so a last.ckpt is always
   maintained, preventing total loss of training progress on interruption.

4. Fix validation never triggering with ContrastBatchSampler. Even with
   batch_sampler=, Lightning 2.6.1 sets max_batches=inf when
   use_distributed_sampler=False, preventing epoch-boundary validation.
   Workaround: use step-based val_check_interval (dataloader_len *
   check_val_every_n_epoch) with check_val_every_n_epoch=None.
   Also make num_sanity_val_steps configurable (default: 2).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
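Items 2 and 4 both hinge on the fact that a batch sampler yields lists of indices per iteration, not single indices, which is why it must go to `DataLoader(batch_sampler=...)` and why its `len()` counts batches. A toy sketch (`ToyBatchSampler` is a hypothetical stand-in for `ContrastBatchSampler`, not the real class):

```python
# A batch sampler yields LISTS of indices; DataLoader must receive it via
# batch_sampler= (leaving batch_size unset), not sampler=.
class ToyBatchSampler:
    def __init__(self, n_items, batch_size):
        self.n_items, self.batch_size = n_items, batch_size

    def __iter__(self):
        idx = list(range(self.n_items))
        for i in range(0, self.n_items, self.batch_size):
            yield idx[i:i + self.batch_size]  # one batch of indices per step

    def __len__(self):
        # Number of BATCHES per epoch (ceil division), not number of items.
        return (self.n_items + self.batch_size - 1) // self.batch_size

# Item 4's workaround: with check_val_every_n_epoch=None, drive validation by
# a step count equal to one "epoch" worth of batches from the sampler.
sampler = ToyBatchSampler(n_items=95, batch_size=8)
val_check_interval = len(sampler) * 1  # dataloader_len * check_val_every_n_epoch
print(val_check_interval)  # 12
```

The step-based interval is exactly equivalent to epoch-boundary validation here, since `len(sampler)` training steps make up one epoch.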
@hummuscience force-pushed the fix/checkpointing-and-validation branch from dcf9b68 to 8571d61 on March 28, 2026, 13:30
@hummuscience
Author

This took a while to dig through. I must admit that after a while I couldn't follow where it was going, since I am not familiar with the details, but the robots seem to agree. Here is a repro script:

"""Reproduce Lightning 2.6.x validation bug in the BEAST training pipeline.

Shows that validation never triggers when using ContrastBatchSampler with
`sampler=` + `batch_size=None` through the `beast.train.train()` code path.

The bug does NOT manifest with a bare `trainer.fit()` call — it requires
the full BEAST code path (Model.train → train() → Trainer.fit with
datamodule, chdir, and TensorBoardLogger).

Demonstrates three scenarios:
  1. BUGGY:  Original BEAST code (sampler= + batch_size=None)
  2. FIX 1:  Switch to batch_sampler= (correct PyTorch API)
  3. FIX 2:  Fix 1 + step-based val_check_interval (belt and suspenders)

Usage:
    CUDA_VISIBLE_DEVICES="" python scripts/repro_lightning_val_bug.py

Requires: beast-backbones installed (editable), no GPU needed.
"""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import copy
import shutil
import tempfile
from pathlib import Path

import yaml
import numpy as np
import lightning.pytorch as pl
from PIL import Image

from beast.data.augmentations import imgaug_pipeline, expand_imgaug_str_to_dict
from beast.data.datasets import BaseDataset
from beast.data.datamodules import BaseDataModule
from beast.data.samplers import ContrastBatchSampler, contrastive_collate_fn
from beast.models.vits import VisionTransformer
from beast.train import get_callbacks

# ---------------------------------------------------------------------------
# Create a small test dataset (5 fake videos, 100 frames each)
# ---------------------------------------------------------------------------

def create_test_data(root: Path, n_videos: int = 5, n_frames: int = 100):
    """Create a minimal frame directory structure for BEAST."""
    root.mkdir(parents=True, exist_ok=True)
    for v in range(n_videos):
        vid_dir = root / f"video_{v:03d}"
        vid_dir.mkdir(exist_ok=True)
        for f in range(n_frames):
            img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
            Image.fromarray(img).save(vid_dir / f"frame_{f:04d}.png")
    return root


# ---------------------------------------------------------------------------
# Three training functions that mirror beast/train.py but with variations
# ---------------------------------------------------------------------------

def _make_config():
    """Minimal BEAST config for testing."""
    return {
        "model": {
            "seed": 0,
            "checkpoint": None,
            "model_class": "vit",
            "model_params": {
                "hidden_size": 768, "num_hidden_layers": 2,
                "num_attention_heads": 2, "intermediate_size": 512,
                "hidden_act": "gelu", "hidden_dropout_prob": 0.0,
                "attention_probs_dropout_prob": 0.0, "initializer_range": 0.02,
                "layer_norm_eps": 1e-12, "image_size": 224, "patch_size": 16,
                "num_channels": 3, "qkv_bias": True,
                "decoder_num_attention_heads": 2, "decoder_hidden_size": 128,
                "decoder_num_hidden_layers": 2, "decoder_intermediate_size": 256,
                "mask_ratio": 0.75, "norm_pix_loss": False,
                "embed_size": 768, "temp_scale": False,
                "random_init": True,  # skip HuggingFace download
                "use_infoNCE": True, "infoNCE_weight": 0.01,
                "use_perceptual_loss": False, "lambda_perceptual": 10.0,
            },
        },
        "training": {
            "seed": 0, "imgaug": "default",
            "train_batch_size": 8, "val_batch_size": 8, "test_batch_size": 8,
            "num_epochs": 3, "num_workers": 0, "num_gpus": 0, "num_nodes": 1,
            "log_every_n_steps": 10, "check_val_every_n_epoch": 1,
            "num_sanity_val_steps": 0,
        },
        "optimizer": {
            "type": "AdamW", "accumulate_grad_batches": 1,
            "lr": 5e-5, "wd": 0.05, "warmup_pct": 0.15,
            "gamma": 0.95, "div_factor": 10, "scheduler": "cosine",
        },
        "data": {"data_dir": None},  # filled in per test
    }


def _run_beast_train(config, data_dir, output_dir, dataloader_mode="buggy"):
    """Run the BEAST training pipeline (mirrors beast/train.py + beast/api/model.py).

    dataloader_mode:
        "buggy"   — original: sampler= + batch_size=None
        "fix1"    — batch_sampler= (correct PyTorch API)
        "fix2"    — batch_sampler= + step-based val_check_interval
    """
    from torch.utils.data import DataLoader

    config = copy.deepcopy(config)
    config["data"]["data_dir"] = str(data_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Replicate beast/api/model.py: Model.from_config + Model.train ---
    model = VisionTransformer(config)

    # --- Replicate beast/train.py: train() ---
    # chdir to output_dir (like Model.train does)
    old_cwd = os.getcwd()
    os.chdir(output_dir)

    try:
        pipe_params = expand_imgaug_str_to_dict(config["training"].get("imgaug", "none"))
        pipeline = imgaug_pipeline(pipe_params)
        dataset = BaseDataset(data_dir=config["data"]["data_dir"], imgaug_pipeline=pipeline)

        dm = BaseDataModule(
            dataset=dataset,
            train_batch_size=config["training"]["train_batch_size"],
            val_batch_size=config["training"]["val_batch_size"],
            test_batch_size=config["training"]["test_batch_size"],
            use_sampler=True,
            num_workers=config["training"]["num_workers"],
            train_probability=0.90,
            val_probability=0.05,
            seed=config["training"]["seed"],
        )

        if dataloader_mode == "buggy":
            # Original BEAST code: sampler= with batch_size=None
            def buggy_train_dl():
                sampler = ContrastBatchSampler(
                    dataset=dm.train_dataset,
                    batch_size=config["training"]["train_batch_size"],
                    seed=config["training"]["seed"],
                )
                return DataLoader(
                    dm.train_dataset,
                    batch_size=None,
                    sampler=sampler,
                    collate_fn=contrastive_collate_fn,
                    num_workers=0,
                )

            def buggy_val_dl():
                return DataLoader(
                    dm.val_dataset,
                    batch_size=config["training"]["val_batch_size"],
                    num_workers=0,
                )

            dm.train_dataloader = buggy_train_dl
            dm.val_dataloader = buggy_val_dl

        elif dataloader_mode in ("fix1", "fix2"):
            # Fixed: batch_sampler=
            def fixed_train_dl():
                sampler = ContrastBatchSampler(
                    dataset=dm.train_dataset,
                    batch_size=config["training"]["train_batch_size"],
                    seed=config["training"]["seed"],
                )
                return DataLoader(
                    dm.train_dataset,
                    batch_sampler=sampler,
                    collate_fn=contrastive_collate_fn,
                    num_workers=0,
                )

            def fixed_val_dl():
                val_sampler = ContrastBatchSampler(
                    dataset=dm.val_dataset,
                    batch_size=config["training"]["val_batch_size"],
                    seed=config["training"]["seed"],
                    shuffle=False,
                )
                return DataLoader(
                    dm.val_dataset,
                    batch_sampler=val_sampler,
                    collate_fn=contrastive_collate_fn,
                    num_workers=0,
                )

            dm.train_dataloader = fixed_train_dl
            dm.val_dataloader = fixed_val_dl

        # Compute steps_per_epoch (as beast/train.py does)
        num_epochs = config["training"]["num_epochs"]
        steps_per_epoch = int(np.ceil(
            len(dm.train_dataset)
            / config["training"]["train_batch_size"]
            / max(config["training"]["num_gpus"], 1)
            / config["training"]["num_nodes"]
        ))
        model.config["optimizer"]["steps_per_epoch"] = steps_per_epoch
        model.config["optimizer"]["total_steps"] = steps_per_epoch * num_epochs

        # Save config
        with open(output_dir / "config.yaml", "w") as f:
            yaml.dump(config, f)

        logger = pl.loggers.TensorBoardLogger("tb_logs", name="")
        callbacks = get_callbacks(lr_monitor=True)

        # Build Trainer kwargs
        trainer_kwargs = dict(
            accelerator="cpu",
            devices=1,
            max_epochs=num_epochs,
            min_epochs=num_epochs,
            log_every_n_steps=config["training"].get("log_every_n_steps", 10),
            callbacks=callbacks,
            logger=logger,
            accumulate_grad_batches=config["optimizer"].get("accumulate_grad_batches", 1),
            num_sanity_val_steps=config["training"].get("num_sanity_val_steps", 0),
            sync_batchnorm=True,
            use_distributed_sampler=False,
        )

        if dataloader_mode == "fix2":
            # Step-based val_check_interval workaround
            trainer_kwargs["check_val_every_n_epoch"] = None
            trainer_kwargs["val_check_interval"] = (
                len(dm.train_dataloader())
                * config["training"].get("check_val_every_n_epoch", 1)
            )
        else:
            trainer_kwargs["check_val_every_n_epoch"] = config["training"].get(
                "check_val_every_n_epoch", 1
            )

        trainer = pl.Trainer(**trainer_kwargs)
        trainer.fit(model=model, datamodule=dm)

        # Check results
        metrics = dict(trainer.callback_metrics)
        has_val = "val_loss" in metrics

        return has_val, metrics

    finally:
        os.chdir(old_cwd)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    # Create test data
    tmpdir = Path(tempfile.mkdtemp())
    data_dir = create_test_data(tmpdir / "frames")
    print(f"Test data: {data_dir}")

    results = {}
    modes = [
        ("buggy", "ORIGINAL: sampler= + batch_size=None"),
        ("fix1",  "FIX 1: batch_sampler= only"),
        ("fix2",  "FIX 2: batch_sampler= + step-based val_check_interval"),
    ]

    for mode, description in modes:
        print(f"\n{'='*60}")
        print(f"TEST: {description}")
        print(f"{'='*60}")

        output_dir = tmpdir / f"output_{mode}"
        has_val, metrics = _run_beast_train(
            _make_config(), data_dir, output_dir, dataloader_mode=mode,
        )

        metric_keys = list(metrics.keys())
        val_keys = [k for k in metric_keys if "val" in k]
        print(f"  Metrics: {metric_keys}")
        print(f"  Val keys: {val_keys}")
        if has_val:
            print(f"  val_loss = {metrics['val_loss']:.4f}")
            print(f"  Result: PASS")
        else:
            print(f"  Result: FAIL — validation never ran")
        results[mode] = has_val

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for mode, description in modes:
        status = "PASS" if results[mode] else "FAIL (bug reproduced)"
        print(f"  {description}: {status}")

    if not results["buggy"] and results["fix2"]:
        print("\nBug confirmed and fix verified.")

    # Cleanup
    shutil.rmtree(tmpdir)


if __name__ == "__main__":
    main()
