150 changes: 147 additions & 3 deletions docs/Evaluation.md
@@ -21,9 +21,11 @@ This guide explains how we launch evaluations for OlmoEarth checkpoints and base
3. [Quick Start](#quick-start)
4. [KNN / Linear Probing](#knn--linear-probing)
5. [Finetune](#finetune-sweep)
6. [Embedding Diagnostics](#embedding-diagnostics)
7. [Tiling Artifact Diagnostics](#tiling-artifact-diagnostics)
8. [Monitoring & Outputs](#monitoring--outputs)
9. [Helpful Files](#helpful-files)
10. [Adding New Eval Datasets (Internal)](#adding-new-eval-datasets-internal)

---

@@ -272,6 +274,148 @@ python -m olmoearth_pretrain.internal.full_eval_sweep_finetune \

---

## Embedding Diagnostics

Embedding diagnostics measure the geometric quality of encoder representations without requiring labeled data. They detect common self-supervised pretraining failure modes such as dimensional collapse, representation crowding, and loss of patch-level diversity within images.

### What it measures

| Metric | Healthy range | What it detects |
|--------|--------------|-----------------|
| `effective_rank` | > 0.7 × D | Dimensional collapse (few active SVD components) |
| `uniformity` | < -2.0 | Poor coverage of the hypersphere (embeddings clustering rather than spreading out) |
| `cosine_sim_mean` | < 0.3 | Representation crowding (all embeddings similar) |
| `intra_cosine_sim_mean` | < 0.5 | Patch collapse within images (bad for segmentation) |

For spatial (patch-level) embeddings, metrics are reported with three prefixes (reference sketches of the underlying computations follow the list):
- `global_*` — all patches flattened together
- `inter_*` — mean-pooled per image, then compared across images
- `intra_*` — patch diversity within each image
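
For reference, here are minimal PyTorch sketches of how these quantities are conventionally computed. The function names are illustrative; the actual implementation lives in `evals/embedding_diagnostics.py` and may differ in normalization and edge-case handling.

```python
import torch
import torch.nn.functional as F


def effective_rank(z: torch.Tensor) -> torch.Tensor:
    """exp(entropy) of the normalized singular-value spectrum of [N, D] embeddings."""
    s = torch.linalg.svdvals(z - z.mean(dim=0))
    p = s / s.sum()
    return torch.exp(-(p * torch.log(p + 1e-12)).sum())


def uniformity(z: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    """Log of the mean Gaussian potential on the unit sphere (Wang & Isola, 2020).

    More negative = embeddings spread more evenly over the hypersphere.
    """
    z = F.normalize(z, dim=-1)
    return torch.log(torch.exp(-t * torch.pdist(z).pow(2)).mean())


def cosine_sim_mean(z: torch.Tensor) -> torch.Tensor:
    """Mean pairwise cosine similarity, excluding self-similarity on the diagonal."""
    z = F.normalize(z, dim=-1)
    sim = z @ z.T
    n = z.shape[0]
    return (sim.sum() - n) / (n * (n - 1))
```

The `inter_*` variants mean-pool patches per image before applying these functions; the `intra_*` variants apply them within each image and average the results.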

### Running during training (in-loop)

Embedding diagnostics are included in the default `build_trainer_config` in `scripts/official/script.py`. They run on a fixed subset of pretrain data (`pretrain_subset_128`) at the interval specified:

```python
DownstreamTaskConfig(
dataset="pretrain_subset_128",
eval_mode=EvalMode.EMBEDDING_DIAGNOSTICS,
embedding_batch_size=4,
eval_interval=Duration.steps(20000),
h5py_dir=H5PY_DIR,
pretrain_max_samples=256,
input_modalities=[Modality.SENTINEL2_L2A.name, Modality.SENTINEL1.name, Modality.LANDSAT.name],
)
```

Metrics are logged to W&B under `eval_embed_diagnostics/<task_name>/<metric>`.
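
To pull these curves programmatically (e.g. to compare runs), the standard W&B API works. The entity, project, run path, and metric key below are placeholders, not values from this repo:

```python
import wandb

api = wandb.Api()
run = api.run("my-entity/my-eval-project/abc123")  # hypothetical run path
key = "eval_embed_diagnostics/pretrain_subset_128/global_effective_rank"  # assumed key layout
history = run.history(keys=[key])  # pandas DataFrame of logged values per step
print(history.tail())
```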

### Running on saved checkpoints

Use `checkpoint_sweep_evals.py` with the `EMBEDDING_DIAGNOSTICS_ONLY` env var:

```bash
EMBEDDING_DIAGNOSTICS_ONLY=1 \
TRAIN_SCRIPT_PATH=scripts/official/base.py \
CHECKPOINT_DIR=/weka/.../checkpoints/my_run \
torchrun olmoearth_pretrain/internal/checkpoint_sweep_evals.py \
evaluate my_run_embed_diag local
```

Or via `all_evals.py`:

```bash
EMBEDDING_DIAGNOSTICS_ONLY=1 \
TRAIN_SCRIPT_PATH=scripts/official/base.py \
python3 olmoearth_pretrain/internal/all_evals.py \
launch_evaluate my_run_embed_diag ai2/saturn-cirrascale
```

### Interpreting results

- **`effective_rank` dropping** → the model is collapsing to fewer dimensions. Often caused by a too-high learning rate or a missing stop-gradient.
- **`cosine_sim_mean` near 1.0** → all embeddings point in the same direction. Complete collapse.
- **`intra_cosine_sim_mean` near 1.0** → patches within images are identical. The model cannot distinguish spatial locations, so segmentation tasks will fail. A toy threshold check is sketched below.
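
As a toy illustration of applying the table's thresholds (the embedding dimension and metric values below are made up):

```python
D = 768  # hypothetical embedding dimension
metrics = {  # made-up values for illustration
    "effective_rank": 412.0,
    "cosine_sim_mean": 0.41,
    "intra_cosine_sim_mean": 0.62,
}

if metrics["effective_rank"] < 0.7 * D:  # 0.7 * 768 = 537.6
    print("warning: possible dimensional collapse")
if metrics["cosine_sim_mean"] > 0.3:
    print("warning: representation crowding")
if metrics["intra_cosine_sim_mean"] > 0.5:
    print("warning: patch-level collapse; segmentation quality will suffer")
```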

---

## Tiling Artifact Diagnostics

Tiling diagnostics detect spatial tiling and striping artifacts in encoder embeddings (see GitHub issue #499). These artifacts appear as periodic grid patterns when the model's spatial representation has systematic biases aligned with patch boundaries.

### What it measures

| Metric | Healthy value | Artifact signal |
|--------|--------------|-----------------|
| `row_col_var_ratio` | ~1.0 | Far from 1.0 → directional stripes (horizontal if >1, vertical if <1) |
| `fft_axis_energy_frac` | ~0.12 | > 0.25 → periodic grid artifacts |
| `fft_dominant_period_px` | — | Period of the strongest artifact in pixels |

Additionally, a **PCA RGB image** is logged to W&B, showing the first 3 PCA components of a sample's spatial embeddings as an RGB image. Healthy embeddings look like a smooth, spatially varying color map; tiling artifacts appear as a visible grid or stripe pattern.

### How it works

1. **Row/column variance ratio**: Averages embeddings along rows and along columns separately, then compares the variances of the two profiles. Isotropic embeddings have a ratio near 1.0; directional stripes cause large deviations.
2. **FFT axis energy**: Projects all patch embeddings onto their first PCA component, computes a 2D FFT per sample, and measures what fraction of spectral energy lies on the horizontal and vertical frequency axes (excluding DC and the k=1 gradient). High axis energy means periodic grid patterns exist.
3. **PCA RGB**: Fits PCA on a single sample's [H, W, D] embeddings and maps the first 3 components to RGB channels. Logged as a `wandb.Image`. All three steps are sketched in code below.
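
Below is a rough sketch of these three steps for one sample's `[H, W, D]` patch embeddings, written from the description above; the authoritative versions are `compute_tiling_artifact_metrics` and `pca_rgb_image` in `evals/embedding_diagnostics.py`.

```python
import torch


def row_col_var_ratio(emb: torch.Tensor) -> torch.Tensor:
    row_profile = emb.mean(dim=1)  # [H, D]: average each row across columns
    col_profile = emb.mean(dim=0)  # [W, D]: average each column across rows
    return row_profile.var(dim=0).mean() / col_profile.var(dim=0).mean()


def fft_axis_energy_frac(emb: torch.Tensor) -> torch.Tensor:
    h, w, d = emb.shape
    x = emb.reshape(-1, d)
    x = x - x.mean(dim=0)
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    pc1 = (x @ vh[0]).reshape(h, w)  # projection onto the first PCA component
    spec = torch.fft.fft2(pc1).abs().pow(2)
    spec[0, 0] = 0.0  # drop DC
    for idx in ((0, 1), (0, -1), (1, 0), (-1, 0)):
        spec[idx] = 0.0  # drop the k=1 gradient bins
    axis_energy = spec[0, :].sum() + spec[:, 0].sum()
    return axis_energy / spec.sum()


def pca_rgb_image(emb: torch.Tensor) -> torch.Tensor:
    h, w, d = emb.shape
    x = emb.reshape(-1, d)
    x = x - x.mean(dim=0)
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    rgb = (x @ vh[:3].T).reshape(h, w, 3)
    lo, hi = rgb.amin(dim=(0, 1)), rgb.amax(dim=(0, 1))
    return (rgb - lo) / (hi - lo + 1e-8)  # per-channel min-max scaling to [0, 1]
```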

### Running during training (in-loop)

Tiling diagnostics are included in the default `build_trainer_config` in `scripts/official/script.py` for 64px and 128px spatial sizes:

```python
DownstreamTaskConfig(
dataset="pretrain_subset_128", # or pretrain_subset_64
eval_mode=EvalMode.TILING_DIAGNOSTICS,
embedding_batch_size=32,
eval_interval=Duration.steps(20000),
h5py_dir=H5PY_DIR,
pretrain_max_samples=128,
patch_size=4,
input_modalities=[Modality.SENTINEL2_L2A.name],
)
```

Metrics appear in W&B under `eval_embed_diagnostics/tiling_64px/*` and `eval_embed_diagnostics/tiling_128px/*`.

### Running on saved checkpoints

Use `checkpoint_sweep_evals.py` with the `TILING_DIAGNOSTICS_ONLY` env var:

```bash
TILING_DIAGNOSTICS_ONLY=1 \
TRAIN_SCRIPT_PATH=scripts/official/base.py \
CHECKPOINT_DIR=/weka/.../checkpoints/my_run \
torchrun olmoearth_pretrain/internal/checkpoint_sweep_evals.py \
evaluate my_run_tiling_diag local
```

Or launch on Beaker:

```bash
TILING_DIAGNOSTICS_ONLY=1 \
TRAIN_SCRIPT_PATH=scripts/official/base.py \
CHECKPOINT_DIR=/weka/.../checkpoints/my_run \
python3 olmoearth_pretrain/internal/checkpoint_sweep_evals.py \
launch_evaluate my_run_tiling_diag ai2/saturn-cirrascale
```

### Interpreting results

- **`fft_axis_energy_frac` > 0.25**: Likely tiling artifacts. Check the PCA RGB image for visible grid lines.
- **`row_col_var_ratio` far from 1.0**: Directional striping. Values > 5 suggest horizontal stripes; values < 0.2 suggest vertical stripes.
- **`fft_dominant_period_px` matches a multiple of the patch size**: If the dominant period aligns with the patch grid (e.g. a 16px period with `patch_size=4` is a repeat every 4 patches), the artifact likely originates in the patch embedding or positional encoding.
- **PCA RGB image shows grid lines**: Visual confirmation. Compare early vs. late checkpoints — artifacts that persist or worsen indicate a systematic architecture issue rather than an early-training transient.

### Relevant source files

- [`evals/embedding_diagnostics.py`](../olmoearth_pretrain/evals/embedding_diagnostics.py) — Metric computation (`compute_tiling_artifact_metrics`, `pca_rgb_image`)
- [`evals/datasets/configs.py`](../olmoearth_pretrain/evals/datasets/configs.py) — `pretrain_subset_64` / `pretrain_subset_128` dataset configs
- [`train/callbacks/evaluator_callback.py`](../olmoearth_pretrain/train/callbacks/evaluator_callback.py) — `_val_tiling_diagnostics()` callback method
- [`internal/all_evals.py`](../olmoearth_pretrain/internal/all_evals.py) — `TILING_DIAG_TASKS` and `EMBED_DIAG_TASKS` task registries

---

## Monitoring & Outputs

- **W&B logging:** Both scripts default to `EVAL_WANDB_PROJECT`. Override with `--project_name` or disable W&B via `--trainer.callbacks.wandb.enabled=False`.
10 changes: 7 additions & 3 deletions olmoearth_pretrain/evals/datasets/__init__.py
@@ -10,6 +10,7 @@
from olmoearth_pretrain.evals.studio_ingest.registry import get_dataset_entry

from .breizhcrops import BreizhCropsDataset
from .configs import dataset_to_config
from .floods_dataset import Sen1Floods11Dataset
from .geobench_dataset import GeobenchDataset
from .mados_dataset import MADOSDataset
@@ -45,13 +46,16 @@ def get_eval_dataset(
    **kwargs: Any,
) -> Dataset:
    """Retrieve an eval dataset from the dataset name."""
    if eval_dataset.startswith("pretrain_subset"):
        patch_size = kwargs.get("pretrain_patch_size", 4)
        config = dataset_to_config(eval_dataset)
        hw_p = config.height_width // patch_size
        return PretrainSubsetDataset(
            h5py_dir=kwargs["h5py_dir"],
            training_modalities=kwargs.get("training_modalities", input_modalities),
            max_samples=kwargs.get("max_samples", 512),
            patch_size=patch_size,
            hw_p=hw_p,
            seed=kwargs.get("pretrain_seed", 42),
        )
    elif eval_dataset.startswith("m-"):
32 changes: 25 additions & 7 deletions olmoearth_pretrain/evals/datasets/configs.py
@@ -12,6 +12,8 @@ def get_eval_mode(task_type: TaskType) -> str:
"""Get the eval mode for a given task type."""
if task_type == TaskType.CLASSIFICATION:
return "knn"
elif task_type == TaskType.DIAGNOSTIC:
return "embedding_diagnostics"
else:
return "linear_probe"

@@ -48,18 +50,34 @@ def from_dict(cls, d: dict[str, Any]) -> "EvalDatasetConfig":
        return cls(**d)


_PRETRAIN_SUBSET_MODALITIES = [
    Modality.SENTINEL2_L2A.name,
    Modality.SENTINEL1.name,
    Modality.LANDSAT.name,
]

DATASET_TO_CONFIG = {
    # Pretrain subset configs for embedding/tiling diagnostics at different spatial sizes.
    # Uses DIAGNOSTIC so the eval wrapper preserves spatial dims [N, H, W, D].
    **{
        f"pretrain_subset_{px}": EvalDatasetConfig(
            task_type=TaskType.DIAGNOSTIC,
            imputes=[],
            num_classes=1,
            is_multilabel=False,
            height_width=px,
            supported_modalities=_PRETRAIN_SUBSET_MODALITIES,
        )
        for px in (64, 128)
    },
    # Backward-compat alias (defaults to 128px).
    "pretrain_subset": EvalDatasetConfig(
        task_type=TaskType.DIAGNOSTIC,
        imputes=[],
        num_classes=1,
        is_multilabel=False,
        height_width=128,
        supported_modalities=_PRETRAIN_SUBSET_MODALITIES,
    ),
    "m-eurosat": EvalDatasetConfig(
        task_type=TaskType.CLASSIFICATION,
3 changes: 2 additions & 1 deletion olmoearth_pretrain/evals/datasets/pretrain_subset.py
@@ -72,5 +72,6 @@ def __getitem__(self, idx: int) -> tuple[MaskedOlmoEarthSample, torch.Tensor]:
        )
        _, sample = self._dataset[args]
        masked = MaskedOlmoEarthSample.from_olmoearthsample(sample)
        pixel_size = self.hw_p * self.patch_size
        dummy_label = torch.zeros(pixel_size, pixel_size, dtype=torch.long)
        return masked, dummy_label