diff --git a/docs/Evaluation.md b/docs/Evaluation.md index bd109a1eb..983130315 100644 --- a/docs/Evaluation.md +++ b/docs/Evaluation.md @@ -21,9 +21,11 @@ This guide explains how we launch evaluations for OlmoEarth checkpoints and base 3. [Quick Start](#quick-start) 4. [KNN / Linear Probing](#knn--linear-probing) 5. [Finetune](#finetune-sweep) -6. [Monitoring & Outputs](#monitoring--outputs) -7. [Helpful Files](#helpful-files) -8. [Adding New Eval Datasets (Internal)](#adding-new-eval-datasets-internal) +6. [Embedding Diagnostics](#embedding-diagnostics) +7. [Tiling Artifact Diagnostics](#tiling-artifact-diagnostics) +8. [Monitoring & Outputs](#monitoring--outputs) +9. [Helpful Files](#helpful-files) +10. [Adding New Eval Datasets (Internal)](#adding-new-eval-datasets-internal) --- @@ -272,6 +274,148 @@ python -m olmoearth_pretrain.internal.full_eval_sweep_finetune \ --- +## Embedding Diagnostics + +Embedding diagnostics measure the geometric quality of encoder representations without requiring labeled data. They detect common self-supervised pretraining failure modes such as dimensional collapse, representation crowding, and patch uniformity loss. + +### What it measures + +| Metric | Healthy range | What it detects | +|--------|--------------|-----------------| +| `effective_rank` | > 0.7 × D | Dimensional collapse (few active SVD components) | +| `uniformity` | < -2.0 | How uniformly embeddings cover the hypersphere | +| `cosine_sim_mean` | < 0.3 | Representation crowding (all embeddings similar) | +| `intra_cosine_sim_mean` | < 0.5 | Patch collapse within images (bad for segmentation) | + +For spatial (patch-level) embeddings, metrics are reported with three prefixes: +- `global_*` — all patches flattened together +- `inter_*` — mean-pooled per image, then compared across images +- `intra_*` — patch diversity within each image + +### Running during training (in-loop) + +Embedding diagnostics are included in the default `build_trainer_config` in `scripts/official/script.py`. They run on a fixed subset of pretrain data (`pretrain_subset_128`) at the interval specified: + +```python +DownstreamTaskConfig( + dataset="pretrain_subset_128", + eval_mode=EvalMode.EMBEDDING_DIAGNOSTICS, + embedding_batch_size=4, + eval_interval=Duration.steps(20000), + h5py_dir=H5PY_DIR, + pretrain_max_samples=256, + input_modalities=[Modality.SENTINEL2_L2A.name, Modality.SENTINEL1.name, Modality.LANDSAT.name], +) +``` + +Metrics are logged to W&B under `eval_embed_diagnostics//`. + +### Running on saved checkpoints + +Use `checkpoint_sweep_evals.py` with the `EMBEDDING_DIAGNOSTICS_ONLY` env var: + +```bash +EMBEDDING_DIAGNOSTICS_ONLY=1 \ +TRAIN_SCRIPT_PATH=scripts/official/base.py \ +CHECKPOINT_DIR=/weka/.../checkpoints/my_run \ +torchrun olmoearth_pretrain/internal/checkpoint_sweep_evals.py \ + evaluate my_run_embed_diag local +``` + +Or via `all_evals.py`: + +```bash +EMBEDDING_DIAGNOSTICS_ONLY=1 \ +TRAIN_SCRIPT_PATH=scripts/official/base.py \ +python3 olmoearth_pretrain/internal/all_evals.py \ + launch_evaluate my_run_embed_diag ai2/saturn-cirrascale +``` + +### Interpreting results + +- **`effective_rank` dropping** → model is collapsing to fewer dimensions. Often happens with too-high learning rate or missing stop-gradient. +- **`cosine_sim_mean` near 1.0** → all embeddings point the same direction. Complete collapse. +- **`intra_cosine_sim_mean` near 1.0** → patches within images are identical. The model cannot distinguish spatial locations, so segmentation tasks will fail. 
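
For intuition about what these numbers represent, the sketch below shows the canonical formulations of `effective_rank` (Roy & Vetterli, 2007) and `uniformity` (Wang & Isola, 2020) on an `[N, D]` embedding matrix. This is a minimal sketch only; the actual implementations live in [`evals/embedding_diagnostics.py`](../olmoearth_pretrain/evals/embedding_diagnostics.py) and may differ in details such as centering, sampling, or numerical safeguards.

```python
import torch
import torch.nn.functional as F


def effective_rank_sketch(emb: torch.Tensor) -> float:
    """Exponential of the entropy of the normalized singular-value spectrum.

    Close to D when all dimensions carry signal; drops toward 1 under
    dimensional collapse.
    """
    s = torch.linalg.svdvals(emb)          # singular values of [N, D]
    p = s / s.sum()                        # normalized spectrum
    entropy = -(p * (p + 1e-12).log()).sum()
    return entropy.exp().item()


def uniformity_sketch(emb: torch.Tensor, t: float = 2.0) -> float:
    """Log of the mean pairwise Gaussian potential on the unit hypersphere.

    More negative means embeddings spread more uniformly over the sphere;
    values creeping toward 0 indicate representation crowding.
    """
    z = F.normalize(emb, dim=-1)
    sq_dists = torch.cdist(z, z).pow(2)    # [N, N] squared Euclidean distances
    off_diag = sq_dists[~torch.eye(len(z), dtype=torch.bool, device=z.device)]
    return (-t * off_diag).exp().mean().log().item()
```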
+ +--- + +## Tiling Artifact Diagnostics + +Tiling diagnostics detect spatial tiling and striping artifacts in encoder embeddings (see GitHub issue #499). These artifacts appear as periodic grid patterns when the model's spatial representation has systematic biases aligned with patch boundaries. + +### What it measures + +| Metric | Healthy value | Artifact signal | +|--------|--------------|-----------------| +| `row_col_var_ratio` | ~1.0 | Far from 1.0 → directional stripes (horizontal if >1, vertical if <1) | +| `fft_axis_energy_frac` | ~0.12 | > 0.25 → periodic grid artifacts | +| `fft_dominant_period_px` | — | Period of the strongest artifact in pixels | + +Additionally, a **PCA RGB image** is logged to W&B, showing the first 3 PCA components of a sample's spatial embeddings as an RGB image. Healthy embeddings look like a smooth, spatially-varying color map. Tiling artifacts appear as a visible grid or stripe pattern. + +### How it works + +1. **Row/column variance ratio**: Averages embeddings along rows and columns separately, then compares their variances. Isotropic embeddings have a ratio near 1.0; directional stripes cause large deviations. +2. **FFT axis energy**: Projects all patch embeddings to their first PCA component, computes a 2D FFT per sample, and measures what fraction of spectral energy lies on the horizontal and vertical frequency axes (excluding DC and the k=1 gradient). High axis energy means periodic grid patterns exist. +3. **PCA RGB**: Fits PCA on a single sample's [H, W, D] embeddings and maps the first 3 components to RGB channels. Logged as a `wandb.Image`. + +### Running during training (in-loop) + +Tiling diagnostics are included in the default `build_trainer_config` in `scripts/official/script.py` for 64px and 128px spatial sizes: + +```python +DownstreamTaskConfig( + dataset="pretrain_subset_128", # or pretrain_subset_64 + eval_mode=EvalMode.TILING_DIAGNOSTICS, + embedding_batch_size=32, + eval_interval=Duration.steps(20000), + h5py_dir=H5PY_DIR, + pretrain_max_samples=128, + patch_size=4, + input_modalities=[Modality.SENTINEL2_L2A.name], +) +``` + +Metrics appear in W&B under `eval_embed_diagnostics/tiling_64px/*` and `eval_embed_diagnostics/tiling_128px/*`. + +### Running on saved checkpoints + +Use `checkpoint_sweep_evals.py` with the `TILING_DIAGNOSTICS_ONLY` env var: + +```bash +TILING_DIAGNOSTICS_ONLY=1 \ +TRAIN_SCRIPT_PATH=scripts/official/base.py \ +CHECKPOINT_DIR=/weka/.../checkpoints/my_run \ +torchrun olmoearth_pretrain/internal/checkpoint_sweep_evals.py \ + evaluate my_run_tiling_diag local +``` + +Or launch on Beaker: + +```bash +TILING_DIAGNOSTICS_ONLY=1 \ +TRAIN_SCRIPT_PATH=scripts/official/base.py \ +CHECKPOINT_DIR=/weka/.../checkpoints/my_run \ +python3 olmoearth_pretrain/internal/checkpoint_sweep_evals.py \ + launch_evaluate my_run_tiling_diag ai2/saturn-cirrascale +``` + +### Interpreting results + +- **`fft_axis_energy_frac` > 0.25**: Likely tiling artifacts. Check the PCA RGB image for visible grid lines. +- **`row_col_var_ratio` far from 1.0**: Directional striping. Values > 5 suggest horizontal stripes; values < 0.2 suggest vertical stripes. +- **`fft_dominant_period_px` matches patch size multiples**: The artifact period aligning with the patch size (e.g. 16px for patch_size=4 at 4-patch intervals) confirms the artifact comes from the patch embedding or positional encoding. +- **PCA RGB image shows grid lines**: Visual confirmation. Compare early vs. 
late checkpoints — artifacts that persist or worsen indicate a systematic architecture issue rather than an early-training transient. + +### Relevant source files + +- [`evals/embedding_diagnostics.py`](../olmoearth_pretrain/evals/embedding_diagnostics.py) — Metric computation (`compute_tiling_artifact_metrics`, `pca_rgb_image`) +- [`evals/datasets/configs.py`](../olmoearth_pretrain/evals/datasets/configs.py) — `pretrain_subset_64` / `pretrain_subset_128` dataset configs +- [`train/callbacks/evaluator_callback.py`](../olmoearth_pretrain/train/callbacks/evaluator_callback.py) — `_val_tiling_diagnostics()` callback method +- [`internal/all_evals.py`](../olmoearth_pretrain/internal/all_evals.py) — `TILING_DIAG_TASKS` and `EMBED_DIAG_TASKS` task registries + +--- + ## Monitoring & Outputs - **W&B logging:** Both scripts default to `EVAL_WANDB_PROJECT`. Override with `--project_name` or disable W&B via `--trainer.callbacks.wandb.enabled=False`. diff --git a/olmoearth_pretrain/evals/datasets/__init__.py b/olmoearth_pretrain/evals/datasets/__init__.py index 6d5f9eb11..b5dceaef7 100644 --- a/olmoearth_pretrain/evals/datasets/__init__.py +++ b/olmoearth_pretrain/evals/datasets/__init__.py @@ -10,6 +10,7 @@ from olmoearth_pretrain.evals.studio_ingest.registry import get_dataset_entry from .breizhcrops import BreizhCropsDataset +from .configs import dataset_to_config from .floods_dataset import Sen1Floods11Dataset from .geobench_dataset import GeobenchDataset from .mados_dataset import MADOSDataset @@ -45,13 +46,16 @@ def get_eval_dataset( **kwargs: Any, ) -> Dataset: """Retrieve an eval dataset from the dataset name.""" - if eval_dataset == "pretrain_subset": + if eval_dataset.startswith("pretrain_subset"): + patch_size = kwargs.get("pretrain_patch_size", 4) + config = dataset_to_config(eval_dataset) + hw_p = config.height_width // patch_size return PretrainSubsetDataset( h5py_dir=kwargs["h5py_dir"], training_modalities=kwargs.get("training_modalities", input_modalities), max_samples=kwargs.get("max_samples", 512), - patch_size=kwargs.get("pretrain_patch_size", 4), - hw_p=kwargs.get("pretrain_hw_p", 8), + patch_size=patch_size, + hw_p=hw_p, seed=kwargs.get("pretrain_seed", 42), ) elif eval_dataset.startswith("m-"): diff --git a/olmoearth_pretrain/evals/datasets/configs.py b/olmoearth_pretrain/evals/datasets/configs.py index 731cb26e8..1c1fbab58 100644 --- a/olmoearth_pretrain/evals/datasets/configs.py +++ b/olmoearth_pretrain/evals/datasets/configs.py @@ -12,6 +12,8 @@ def get_eval_mode(task_type: TaskType) -> str: """Get the eval mode for a given task type.""" if task_type == TaskType.CLASSIFICATION: return "knn" + elif task_type == TaskType.DIAGNOSTIC: + return "embedding_diagnostics" else: return "linear_probe" @@ -48,18 +50,34 @@ def from_dict(cls, d: dict[str, Any]) -> "EvalDatasetConfig": return cls(**d) +_PRETRAIN_SUBSET_MODALITIES = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, +] + DATASET_TO_CONFIG = { - # Dummy config — only used for embedding diagnostics, not actual classification. + # Pretrain subset configs for embedding/tiling diagnostics at different spatial sizes. + # Uses DIAGNOSTIC so the eval wrapper preserves spatial dims [N, H, W, D]. + **{ + f"pretrain_subset_{px}": EvalDatasetConfig( + task_type=TaskType.DIAGNOSTIC, + imputes=[], + num_classes=1, + is_multilabel=False, + height_width=px, + supported_modalities=_PRETRAIN_SUBSET_MODALITIES, + ) + for px in (64, 128) + }, + # Backward-compat alias (defaults to 128px). 
"pretrain_subset": EvalDatasetConfig( - task_type=TaskType.CLASSIFICATION, + task_type=TaskType.DIAGNOSTIC, imputes=[], num_classes=1, is_multilabel=False, - supported_modalities=[ - Modality.SENTINEL2_L2A.name, - Modality.SENTINEL1.name, - Modality.LANDSAT.name, - ], + height_width=128, + supported_modalities=_PRETRAIN_SUBSET_MODALITIES, ), "m-eurosat": EvalDatasetConfig( task_type=TaskType.CLASSIFICATION, diff --git a/olmoearth_pretrain/evals/datasets/pretrain_subset.py b/olmoearth_pretrain/evals/datasets/pretrain_subset.py index bcfcf13e0..b7c03b1cc 100644 --- a/olmoearth_pretrain/evals/datasets/pretrain_subset.py +++ b/olmoearth_pretrain/evals/datasets/pretrain_subset.py @@ -72,5 +72,6 @@ def __getitem__(self, idx: int) -> tuple[MaskedOlmoEarthSample, torch.Tensor]: ) _, sample = self._dataset[args] masked = MaskedOlmoEarthSample.from_olmoearthsample(sample) - dummy_label = torch.tensor(0, dtype=torch.long) + pixel_size = self.hw_p * self.patch_size + dummy_label = torch.zeros(pixel_size, pixel_size, dtype=torch.long) return masked, dummy_label diff --git a/olmoearth_pretrain/evals/embedding_diagnostics.py b/olmoearth_pretrain/evals/embedding_diagnostics.py index c2c844878..e9f7f8706 100644 --- a/olmoearth_pretrain/evals/embedding_diagnostics.py +++ b/olmoearth_pretrain/evals/embedding_diagnostics.py @@ -1,8 +1,11 @@ -"""Embedding quality diagnostics for detecting representation collapse. +"""Embedding quality diagnostics for detecting representation collapse and tiling artifacts. Computes geometry metrics on embedding matrices to diagnose failure modes in self-supervised pretraining (dimensional collapse, crowding, etc.). +Also detects spatial tiling/striping artifacts (see GitHub issue #499) by measuring +row/column variance anisotropy and periodic energy in the Fourier domain. + Supports two embedding shapes: - [N, D]: image-level (classification). One embedding per sample. - [N, P, D] or [N, H, W, D]: patch-level (segmentation). Multiple patches per sample. @@ -16,7 +19,9 @@ import logging +import numpy as np import torch +from sklearn.decomposition import PCA from torch import Tensor logger = logging.getLogger(__name__) @@ -201,3 +206,165 @@ def compute_spatial_embedding_diagnostics(embeddings: Tensor) -> dict[str, float metrics[f"intra_{k}"] = v return metrics + + +# --------------------------------------------------------------------------- +# Tiling / striping artifact detection (GitHub issue #499) +# --------------------------------------------------------------------------- + +MAX_TILING_SAMPLES = 64 + + +def _row_col_variance_ratio(embeddings: Tensor) -> float: + """Detect striping via variance of row-means vs column-means. + + Horizontal stripes → high row-mean variance relative to column-mean variance. + Vertical stripes → high column-mean variance relative to row-mean variance. + + Args: + embeddings: [N, H, W, D] spatial embeddings. + + Returns: + Ratio of row-variance to col-variance (1.0 = isotropic). + """ + emb = embeddings.float() + row_means = emb.mean(dim=2) # [N, H, D] + col_means = emb.mean(dim=1) # [N, W, D] + + row_var = row_means.var(dim=1).mean().item() + col_var = col_means.var(dim=1).mean().item() + + return row_var / (col_var + 1e-12) + + +def _fourier_grid_energy(embeddings: Tensor, patch_size: int) -> dict[str, float]: + """Detect periodic tiling artifacts via 2D FFT on the first PCA component. + + Computes the fraction of spectral energy concentrated on the + horizontal and vertical axes of the frequency domain (excluding DC). 
+ Also identifies the dominant frequency and its period in pixels. + + Args: + embeddings: [N, H, W, D] spatial embeddings (H, W are in patch space). + patch_size: pixel size of each patch, used to convert period to pixels. + + Returns: + fft_axis_energy_frac: fraction of energy on grid axes (~0.12 healthy, >0.25 artifacts). + fft_dominant_period_px: period of the strongest axis frequency in pixels. + """ + emb = embeddings.float() + n, h, w, d = emb.shape + flat = emb.reshape(-1, d) + pca = PCA(n_components=1) + pc1 = pca.fit_transform(flat.cpu().numpy()) # [N*H*W, 1] + pc1_map = torch.from_numpy(pc1.reshape(n, h, w)) # [N, H, W] + + fft_2d = torch.fft.fft2(pc1_map, norm="ortho") + mag = fft_2d.abs().mean(dim=0) # [H, W] + + mag[0, 0] = 0.0 + + total_energy = mag.sum().item() + 1e-12 + h_axis_energy = mag[:, 0].sum().item() + w_axis_energy = mag[0, :].sum().item() + axis_energy = h_axis_energy + w_axis_energy + + # Find dominant axis-aligned frequency, skipping k=1 (just the overall + # spatial gradient) to find actual periodic artifacts. + min_k = 2 + axis_mags = [] + for k in range(min_k, h): + axis_mags.append((mag[k, 0].item(), h / k)) + for k in range(min_k, w): + axis_mags.append((mag[0, k].item(), w / k)) + + dominant_period_patches = 0.0 + if axis_mags: + _, dominant_period_patches = max(axis_mags, key=lambda x: x[0]) + + return { + "fft_axis_energy_frac": axis_energy / total_energy, + "fft_dominant_period_px": dominant_period_patches * patch_size, + } + + +def compute_tiling_artifact_metrics( + embeddings: Tensor, patch_size: int = 4 +) -> dict[str, float]: + """Compute metrics that detect spatial tiling/striping artifacts. + + Returns 3 metrics: + - row_col_var_ratio: 1.0 = isotropic, far from 1.0 = directional stripes + - fft_axis_energy_frac: ~0.12 = healthy, >0.25 = periodic grid artifacts + - fft_dominant_period_px: period of strongest artifact in pixels + + Args: + embeddings: [N, H, W, D] spatial embeddings (H, W in patch space). + patch_size: pixel size of each patch for converting periods. + + Returns empty dict if input doesn't have spatial dimensions (H, W >= 2). + """ + if embeddings.ndim != 4: + logger.warning( + "Tiling artifact metrics require [N, H, W, D] embeddings, " + f"got shape {embeddings.shape}" + ) + return {} + + n, h, w, _d = embeddings.shape + if h < 2 or w < 2: + logger.warning(f"Spatial dims too small for tiling metrics: H={h}, W={w}") + return {} + + if n > MAX_TILING_SAMPLES: + idx = torch.randperm(n, device=embeddings.device)[:MAX_TILING_SAMPLES] + embeddings = embeddings[idx] + + metrics: dict[str, float] = {} + + metrics["row_col_var_ratio"] = _row_col_variance_ratio(embeddings) + + if h >= 4 and w >= 4: + fft_stats = _fourier_grid_energy(embeddings, patch_size) + metrics["fft_axis_energy_frac"] = fft_stats["fft_axis_energy_frac"] + metrics["fft_dominant_period_px"] = fft_stats["fft_dominant_period_px"] + + return metrics + + +def pca_rgb_image(embeddings: Tensor) -> np.ndarray: + """Render the first 3 PCA components of spatial embeddings as an RGB image. + + Takes a single image's spatial embeddings [H, W, D] and returns + an [H, W, 3] uint8 array suitable for wandb.Image / matplotlib. + + If called with [N, H, W, D], uses the first sample. 
+ """ + if embeddings.ndim == 4: + embeddings = embeddings[0] + if embeddings.ndim != 3: + raise ValueError(f"Expected [H, W, D] or [N, H, W, D], got {embeddings.shape}") + + h, w, d = embeddings.shape + flat = embeddings.reshape(-1, d).float().cpu().numpy() + + n_components = min(3, d) + pca = PCA(n_components=n_components) + components = pca.fit_transform(flat) # [H*W, 3] + + # Normalize each component to [0, 1] + for i in range(n_components): + c = components[:, i] + cmin, cmax = c.min(), c.max() + if cmax - cmin > 1e-8: + components[:, i] = (c - cmin) / (cmax - cmin) + else: + components[:, i] = 0.5 + + # Pad to 3 channels if fewer + if n_components < 3: + pad = np.zeros((components.shape[0], 3 - n_components), dtype=np.float32) + components = np.concatenate([components, pad], axis=1) + + rgb = (components.reshape(h, w, 3) * 255).astype(np.uint8) + return rgb diff --git a/olmoearth_pretrain/evals/eval_wrapper.py b/olmoearth_pretrain/evals/eval_wrapper.py index 4c535f94a..60c3d42b3 100644 --- a/olmoearth_pretrain/evals/eval_wrapper.py +++ b/olmoearth_pretrain/evals/eval_wrapper.py @@ -66,7 +66,7 @@ def __init__( self.patch_size = patch_size self.pooling_type = pooling_type self.concat_features = concat_features - self.spatial_pool = task_type == TaskType.SEGMENTATION + self.spatial_pool = task_type in (TaskType.SEGMENTATION, TaskType.DIAGNOSTIC) self.use_pooled_tokens = use_pooled_tokens if self.use_pooled_tokens: assert isinstance(self.model, EncodeEarlyAttnPool), ( diff --git a/olmoearth_pretrain/evals/metrics.py b/olmoearth_pretrain/evals/metrics.py index db2bf3993..90d5527c2 100644 --- a/olmoearth_pretrain/evals/metrics.py +++ b/olmoearth_pretrain/evals/metrics.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import Any +import numpy as np import torch from sklearn.metrics import accuracy_score, f1_score @@ -37,6 +38,7 @@ class EvalTaskResult: bootstrap_stats: dict[str, Any] = field(default_factory=dict) eval_time: float | None = None embedding_diagnostics: dict[str, float] = field(default_factory=dict) + pca_rgb: np.ndarray | None = None @dataclass diff --git a/olmoearth_pretrain/evals/task_types.py b/olmoearth_pretrain/evals/task_types.py index ecc81bc0b..0b48aa941 100644 --- a/olmoearth_pretrain/evals/task_types.py +++ b/olmoearth_pretrain/evals/task_types.py @@ -8,6 +8,7 @@ class TaskType(StrEnum): CLASSIFICATION = "classification" SEGMENTATION = "segmentation" + DIAGNOSTIC = "diagnostic" class SplitName(StrEnum): diff --git a/olmoearth_pretrain/internal/all_evals.py b/olmoearth_pretrain/internal/all_evals.py index b83dc6c56..8f6e99dcc 100644 --- a/olmoearth_pretrain/internal/all_evals.py +++ b/olmoearth_pretrain/internal/all_evals.py @@ -401,9 +401,11 @@ def load_user_module(path: str) -> Any: ), } +H5PY_DIR = "/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828" + EMBED_DIAG_TASKS = { - "pretrain_subset": DownstreamTaskConfig( - dataset="pretrain_subset", + "pretrain_subset_128": DownstreamTaskConfig( + dataset="pretrain_subset_128", embedding_batch_size=4, num_workers=2, pooling_type=PoolingType.MEAN, @@ -415,11 +417,27 @@ def load_user_module(path: str) -> Any: Modality.LANDSAT.name, ], eval_mode=EvalMode.EMBEDDING_DIAGNOSTICS, - 
h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + h5py_dir=H5PY_DIR, pretrain_max_samples=256, ), } +TILING_DIAG_TASKS = { + f"tiling_{px}px": DownstreamTaskConfig( + dataset=f"pretrain_subset_{px}", + embedding_batch_size=32, + num_workers=0, + pooling_type=PoolingType.MEAN, + eval_mode=EvalMode.TILING_DIAGNOSTICS, + eval_interval=Duration.epochs(1), + h5py_dir=H5PY_DIR, + pretrain_max_samples=128, + patch_size=4, + input_modalities=[Modality.SENTINEL2_L2A.name], + ) + for px in (64, 128) +} + FT_EVAL_TASKS = { "m_eurosat": DownstreamTaskConfig( dataset="m-eurosat", diff --git a/olmoearth_pretrain/internal/checkpoint_sweep_evals.py b/olmoearth_pretrain/internal/checkpoint_sweep_evals.py index 9058c1469..f982b71b0 100644 --- a/olmoearth_pretrain/internal/checkpoint_sweep_evals.py +++ b/olmoearth_pretrain/internal/checkpoint_sweep_evals.py @@ -48,6 +48,7 @@ from olmoearth_pretrain.internal.all_evals import ( EMBED_DIAG_TASKS, EVAL_TASKS, + TILING_DIAG_TASKS, load_user_module, ) from olmoearth_pretrain.internal.constants import EVAL_WANDB_PROJECT, WANDB_ENTITY @@ -190,6 +191,21 @@ def evaluate_checkpoints( f"eval_embed_diagnostics/{evaluator.evaluation_name}/{k}" ] = v + if ( + result.pca_rgb is not None + and wandb_callback.enabled + and get_rank() == 0 + ): + wandb_callback.wandb.log( + { + f"eval_embed_diagnostics/{evaluator.evaluation_name}/pca_rgb": wandb_callback.wandb.Image( + result.pca_rgb, + caption=f"PCA RGB — {evaluator.evaluation_name} (step {step_num})", + ), + "checkpoint_step": step_num, + } + ) + metrics[f"eval_time/{evaluator.evaluation_name}"] = eval_time logger.info( @@ -213,7 +229,9 @@ def evaluate_checkpoints( def _get_eval_tasks() -> dict: - """Select task set based on EMBEDDING_DIAGNOSTICS_ONLY env var.""" + """Select task set based on environment variables.""" + if os.environ.get("TILING_DIAGNOSTICS_ONLY"): + return TILING_DIAG_TASKS if os.environ.get("EMBEDDING_DIAGNOSTICS_ONLY"): return EMBED_DIAG_TASKS return EVAL_TASKS diff --git a/olmoearth_pretrain/train/callbacks/evaluator_callback.py b/olmoearth_pretrain/train/callbacks/evaluator_callback.py index 1a5c3b202..653f8d077 100644 --- a/olmoearth_pretrain/train/callbacks/evaluator_callback.py +++ b/olmoearth_pretrain/train/callbacks/evaluator_callback.py @@ -33,6 +33,8 @@ from olmoearth_pretrain.evals.embedding_diagnostics import ( compute_embedding_diagnostics, compute_spatial_embedding_diagnostics, + compute_tiling_artifact_metrics, + pca_rgb_image, ) from olmoearth_pretrain.evals.embedding_transforms import ( dequantize_embeddings, @@ -65,6 +67,7 @@ class EvalMode(StrEnum): LINEAR_PROBE = "linear_probe" FINETUNE = "finetune" EMBEDDING_DIAGNOSTICS = "embedding_diagnostics" + TILING_DIAGNOSTICS = "tiling_diagnostics" @dataclass @@ -264,7 +267,7 @@ def _get_data_loader( worker_init_fn = partial(_seed_worker, base_seed=split_seed) extra_kwargs: dict[str, Any] = {} - if self.dataset == "pretrain_subset" and self.h5py_dir is not None: + if self.dataset.startswith("pretrain_subset") and self.h5py_dir is not None: extra_kwargs["h5py_dir"] = self.h5py_dir extra_kwargs["training_modalities"] = self.input_modalities extra_kwargs["max_samples"] = self.pretrain_max_samples @@ -292,7 +295,7 @@ def _get_embeddings( self, data_loader: DataLoader, is_train: bool ) -> tuple[torch.Tensor, torch.Tensor]: """Get the embeddings for the given data 
loader.""" - print( + logger.info( f"Getting embeddings for {self.dataset} with norm method {self.norm_method}" ) if hasattr(self.trainer.train_module.model, "encoder"): @@ -543,7 +546,7 @@ def _val_finetune(self) -> EvalTaskResult: return result def _val_embedding_diagnostics(self) -> EvalTaskResult: - """Compute embedding diagnostics only (no downstream task).""" + """Compute embedding geometry diagnostics (effective rank, uniformity, etc.).""" logger.info(f"Computing embedding diagnostics for {self.dataset}") data_loader = self._get_data_loader("train", self.embedding_batch_size) embeddings, _ = self._get_embeddings(data_loader, is_train=False) @@ -559,10 +562,47 @@ def _val_embedding_diagnostics(self) -> EvalTaskResult: result.embedding_diagnostics = diagnostics return result + def _val_tiling_diagnostics(self) -> EvalTaskResult: + """Compute tiling/striping artifact metrics and PCA RGB visualization. + + Requires spatial embeddings [N, H, W, D]. + """ + logger.info(f"Computing tiling diagnostics for {self.dataset}") + data_loader = self._get_data_loader("train", self.embedding_batch_size) + embeddings, _ = self._get_embeddings(data_loader, is_train=False) + logger.info(f"Embeddings shape for {self.dataset}: {embeddings.shape}") + + if embeddings.ndim != 4: + raise ValueError( + f"Tiling diagnostics requires [N, H, W, D] embeddings, " + f"got shape {embeddings.shape}. Use a segmentation-type dataset." + ) + + diagnostics = compute_tiling_artifact_metrics( + embeddings, patch_size=self.patch_size + ) + logger.info(f"Tiling artifact metrics for {self.dataset}: {diagnostics}") + + pca_rgb = None + try: + pca_rgb = pca_rgb_image(embeddings) + except Exception: + logger.warning( + f"Failed to generate PCA RGB image for {self.dataset}", + exc_info=True, + ) + + result = EvalTaskResult(val_result=None, test_result=None) + result.embedding_diagnostics = diagnostics + result.pca_rgb = pca_rgb + return result + def val(self) -> EvalTaskResult: """Validate the model on the downstream task.""" if self.eval_mode == EvalMode.EMBEDDING_DIAGNOSTICS: return self._val_embedding_diagnostics() + elif self.eval_mode == EvalMode.TILING_DIAGNOSTICS: + return self._val_tiling_diagnostics() elif self.eval_mode in (EvalMode.KNN, EvalMode.LINEAR_PROBE): return self._val_embed_probe() elif self.eval_mode == EvalMode.FINETUNE: @@ -847,7 +887,24 @@ def _perform_eval(self, evaluator: DownstreamEvaluator) -> EvalTaskResult: f"eval_embed_diagnostics/{evaluator.evaluation_name}/{metric_name}", metric_value, ) - + if result.pca_rgb is not None: + try: + wandb_callback = next( + cb + for cb in self.trainer._iter_callbacks() + if isinstance(cb, OlmoEarthWandBCallback) + ) + if wandb_callback.enabled: + wandb_callback.wandb.log( + { + f"eval_embed_diagnostics/{evaluator.evaluation_name}/pca_rgb": wandb_callback.wandb.Image( + result.pca_rgb, + caption=f"PCA RGB — {evaluator.evaluation_name} (step {self.step})", + ) + } + ) + except StopIteration: + logger.debug("No WandB callback found, skipping PCA RGB logging") eval_time = time.monotonic() - start_time self.trainer.record_metric(f"eval_time/{evaluator.evaluation_name}", eval_time) logger.info( @@ -918,9 +975,9 @@ def build(self, trainer: Trainer) -> Callback | None: continue config = dataset_to_config(task.dataset) - if ( - config.task_type == TaskType.SEGMENTATION - and task.eval_mode != EvalMode.EMBEDDING_DIAGNOSTICS + if config.task_type == TaskType.SEGMENTATION and task.eval_mode not in ( + EvalMode.EMBEDDING_DIAGNOSTICS, + EvalMode.TILING_DIAGNOSTICS, ): if 
task.probe_lr is None and task.ft_lr is None: raise ValueError( diff --git a/scripts/official/script.py b/scripts/official/script.py index b91d61dd3..da4ee434f 100644 --- a/scripts/official/script.py +++ b/scripts/official/script.py @@ -37,7 +37,10 @@ OlmoEarthSpeedMonitorCallback, OlmoEarthWandBCallback, ) -from olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig +from olmoearth_pretrain.train.callbacks.evaluator_callback import ( + DownstreamTaskConfig, + EvalMode, +) from olmoearth_pretrain.train.loss import LossConfig from olmoearth_pretrain.train.masking import MaskingConfig from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( @@ -48,6 +51,7 @@ MAX_PATCH_SIZE = 8 MIN_PATCH_SIZE = 1 +H5PY_DIR = "/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828" def build_common_components( @@ -227,6 +231,21 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: input_modalities=[Modality.SENTINEL2_L2A.name], epochs=50, ), + **{ + f"tiling_{px}px": DownstreamTaskConfig( + dataset=f"pretrain_subset_{px}", + embedding_batch_size=32, + num_workers=0, + pooling_type=PoolingType.MEAN, + eval_mode=EvalMode.TILING_DIAGNOSTICS, + eval_interval=Duration.steps(20000), + h5py_dir=H5PY_DIR, + pretrain_max_samples=128, + patch_size=4, + input_modalities=[Modality.SENTINEL2_L2A.name], + ) + for px in (64, 128) + }, } trainer_config = ( TrainerConfig( diff --git a/scripts/tools/get_max_eval_metrics_from_wandb.py b/scripts/tools/get_max_eval_metrics_from_wandb.py index 6bb76015e..5d658a569 100644 --- a/scripts/tools/get_max_eval_metrics_from_wandb.py +++ b/scripts/tools/get_max_eval_metrics_from_wandb.py @@ -8,8 +8,8 @@ import numpy as np import pandas as pd -import wandb +import wandb from olmoearth_pretrain.evals.models import ( MODELS_WITH_MULTIPLE_SIZES, BaselineModelName, diff --git a/tests/unit/eval/test_embedding_diagnostics.py b/tests/unit/eval/test_embedding_diagnostics.py index 90940bc22..e457a4c03 100644 --- a/tests/unit/eval/test_embedding_diagnostics.py +++ b/tests/unit/eval/test_embedding_diagnostics.py @@ -1,14 +1,17 @@ """Unit tests for embedding diagnostics.""" +import numpy as np import pytest import torch from olmoearth_pretrain.evals.embedding_diagnostics import ( compute_embedding_diagnostics, compute_spatial_embedding_diagnostics, + compute_tiling_artifact_metrics, effective_rank, embedding_norm_stats, pairwise_cosine_stats, + pca_rgb_image, uniformity, ) @@ -171,3 +174,98 @@ def test_rejects_2d(self) -> None: """2D input raises ValueError.""" with pytest.raises(ValueError, match="3\\+ dim"): compute_spatial_embedding_diagnostics(torch.randn(10, 64)) + + +class TestTilingArtifactMetrics: + """Tests for tiling artifact detection metrics.""" + + def test_returns_all_keys(self) -> None: + """All 3 key metrics are present for 4D input.""" + embeddings = torch.randn(8, 8, 8, 32) + metrics = compute_tiling_artifact_metrics(embeddings) + expected_keys = { + "row_col_var_ratio", + "fft_axis_energy_frac", + "fft_dominant_period_px", + } + assert expected_keys == set(metrics.keys()) + + def test_isotropic_random_embeddings(self) -> None: + """Random embeddings should have var ratio near 1.""" + embeddings = torch.randn(16, 16, 16, 64) + metrics = compute_tiling_artifact_metrics(embeddings) + assert 0.5 < metrics["row_col_var_ratio"] < 2.0 + + def 
test_horizontal_stripes_detected(self) -> None: + """Embeddings with horizontal stripes have high row_col_var_ratio.""" + h, w, d = 16, 16, 32 + row_pattern = torch.randn(1, h, 1, d).expand(8, h, w, d) + # Add different base per sample + embeddings = row_pattern + torch.randn(8, 1, 1, d) * 0.1 + metrics = compute_tiling_artifact_metrics(embeddings) + assert metrics["row_col_var_ratio"] > 5.0 + + def test_vertical_stripes_detected(self) -> None: + """Embeddings with vertical stripes have low row_col_var_ratio.""" + h, w, d = 16, 16, 32 + col_pattern = torch.randn(1, 1, w, d).expand(8, h, w, d) + embeddings = col_pattern + torch.randn(8, 1, 1, d) * 0.1 + metrics = compute_tiling_artifact_metrics(embeddings) + assert metrics["row_col_var_ratio"] < 0.2 + + def test_rejects_non_4d(self) -> None: + """Non-4D input returns empty dict.""" + metrics = compute_tiling_artifact_metrics(torch.randn(10, 16, 32)) + assert metrics == {} + + def test_small_spatial_skips_fft(self) -> None: + """Spatial dims < 4 skip FFT metrics.""" + embeddings = torch.randn(8, 3, 3, 32) + metrics = compute_tiling_artifact_metrics(embeddings) + assert "fft_axis_energy_frac" not in metrics + assert "row_col_var_ratio" in metrics + + def test_periodic_stripes_fft(self) -> None: + """Periodic vertical stripes produce high FFT axis energy.""" + h, w, d = 16, 16, 4 + patch_size = 4 + base = torch.randn(8, h, w, d) * 0.01 + # Add periodic vertical pattern (period=4 patches = 16px) + for col in range(w): + base[:, :, col, 0] += 10.0 * torch.sin( + torch.tensor(2.0 * torch.pi * col / 4.0) + ) + metrics = compute_tiling_artifact_metrics(base, patch_size=patch_size) + assert metrics["fft_axis_energy_frac"] > 0.1 + # Dominant period should be 4 patches * 4 px = 16 px + assert abs(metrics["fft_dominant_period_px"] - 16.0) < 1.0 + + +class TestPcaRgbImage: + """Tests for PCA RGB visualization.""" + + def test_output_shape_3d(self) -> None: + """3D input [H, W, D] returns [H, W, 3] uint8.""" + emb = torch.randn(8, 8, 32) + rgb = pca_rgb_image(emb) + assert rgb.shape == (8, 8, 3) + assert rgb.dtype == np.uint8 + + def test_output_shape_4d(self) -> None: + """4D input [N, H, W, D] uses first sample, returns [H, W, 3].""" + emb = torch.randn(4, 8, 8, 32) + rgb = pca_rgb_image(emb) + assert rgb.shape == (8, 8, 3) + assert rgb.dtype == np.uint8 + + def test_values_in_range(self) -> None: + """Output values are in [0, 255].""" + emb = torch.randn(12, 12, 64) + rgb = pca_rgb_image(emb) + assert rgb.min() >= 0 + assert rgb.max() <= 255 + + def test_rejects_2d(self) -> None: + """2D input raises ValueError.""" + with pytest.raises(ValueError, match="Expected"): + pca_rgb_image(torch.randn(100, 32))
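
For quick, trainer-free inspection, the new diagnostics can also be driven directly on a saved spatial embedding tensor. A minimal sketch, assuming a hypothetical `.pt` dump of `[N, H, W, D]` embeddings; the `compute_tiling_artifact_metrics` and `pca_rgb_image` signatures match the functions added in this change:

```python
import torch

from olmoearth_pretrain.evals.embedding_diagnostics import (
    compute_tiling_artifact_metrics,
    pca_rgb_image,
)

# Hypothetical dump of [N, H, W, D] spatial embeddings from a debugging run.
embeddings = torch.load("spatial_embeddings.pt")

metrics = compute_tiling_artifact_metrics(embeddings, patch_size=4)
print(metrics)  # row_col_var_ratio, fft_axis_energy_frac, fft_dominant_period_px

rgb = pca_rgb_image(embeddings)  # [H, W, 3] uint8 for the first sample
# Save for visual inspection, e.g.:
# from PIL import Image; Image.fromarray(rgb).save("pca_rgb.png")
```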