Merged

18 commits
9a05cc4
add pixel projection
favyen2 Apr 22, 2026
7806fa3
Merge remote-tracking branch 'origin/main' into favyen/20260422-add-h…
favyen2 Apr 28, 2026
a95b7c7
Merge remote-tracking branch 'origin/gabi/masking-ablations' into fav…
favyen2 Apr 28, 2026
2306acc
add tiny scripts
favyen2 Apr 29, 2026
a05bb26
Merge remote-tracking branch 'origin/main' into favyen/20260422-add-h…
favyen2 Apr 30, 2026
9bf2e43
Merge branch 'favyen/20260501-band-dropout-fix' into favyen/20260422-…
favyen2 May 1, 2026
74155e9
Merge branch 'favyen/20260501-band-dropout-fix' into favyen/20260422-…
favyen2 May 1, 2026
b87f8ee
Merge branch 'favyen/20260501-band-dropout-fix' into favyen/20260422-…
favyen2 May 1, 2026
ee9f323
Merge branch 'favyen/20260501-band-dropout-fix' into favyen/20260422-…
favyen2 May 1, 2026
0a5a127
Merge branch 'gabi/masking-ablations' into favyen/20260422-add-hidden…
gabrieltseng May 2, 2026
2431813
Add ft launch script
gabrieltseng May 2, 2026
5678f0c
tmp
gabrieltseng May 2, 2026
8e64227
Try hidden1 but also with one extra 768->768 after patchification.
uakfdotb May 3, 2026
9ba6f6f
Update Beaker budget (ai2/es-platform -> ai2/atec-olmoearth)
uakfdotb May 3, 2026
c12e7bb
Merge branch 'favyen/20260502-beaker-budget' into favyen/20260422-add…
uakfdotb May 3, 2026
886ed5f
Merge branch 'favyen/20260422-add-hidden-layer-to-initial-projection'…
uakfdotb May 3, 2026
e0aae5e
24 head experiment
uakfdotb May 3, 2026
7e928b5
Merge remote-tracking branch 'origin/main' into favyen/20260422-add-h…
favyen2 May 4, 2026
4 changes: 2 additions & 2 deletions docs/Setup-Internal.md
@@ -43,7 +43,7 @@ Create a GitHub token that can clone this repo on Beaker. Generate a token [here

```bash
beaker config set default_workspace ai2/earth-systems
beaker workspace set-budget ai2/earth-systems ai2/es-platform
beaker workspace set-budget ai2/earth-systems ai2/atec-olmoearth
```

### 3. Set Beaker Secrets
@@ -144,7 +144,7 @@ Evaluation datasets have default paths configured in [`olmoearth_pretrain/evals/
## Beaker Information

**Quick Reference:**
- **Budget:** `ai2/es-platform`
- **Budget:** `ai2/atec-olmoearth`
- **Workspace:** `ai2/earth-systems`
- **Weka:** `weka://dfive-default`

2 changes: 1 addition & 1 deletion olmoearth_pretrain/dataset_creation/internal_docs.md
@@ -4,7 +4,7 @@ A Beaker session can be used to run most of the window creation and data
materialization steps:

```
beaker session create --budget ai2/es-platform --workspace ai2/earth-systems --priority high --gpus 1 --shared-memory 128GiB --bare --mount src=weka,ref=dfive-default,dst=/weka/dfive-default
beaker session create --budget ai2/atec-olmoearth --workspace ai2/earth-systems --priority high --gpus 1 --shared-memory 128GiB --bare --mount src=weka,ref=dfive-default,dst=/weka/dfive-default
```

The only exception is for Sentinel-1 and Sentinel-2 L2A, where it may be desirable to
2 changes: 1 addition & 1 deletion olmoearth_pretrain/inference_benchmarking/constants.py
@@ -1,7 +1,7 @@
"""Shared constants important for inference throughput benchmarking."""

# BEAKER-LAND
BEAKER_BUDGET = "ai2/es-platform"
BEAKER_BUDGET = "ai2/atec-olmoearth"
BEAKER_WORKSPACE = "ai2/earth-systems"
WEKA_BUCKET = "dfive-default"
BEAKER_TASK_PRIORITY = "normal"
2 changes: 1 addition & 1 deletion olmoearth_pretrain/internal/common.py
@@ -24,7 +24,7 @@
)

logger = logging.getLogger(__name__)
BUDGET = "ai2/es-platform"
BUDGET = "ai2/atec-olmoearth"
WORKSPACE = "ai2/earth-systems"

DEFAULT_OLMOEARTH_PRETRAIN_WEKA_BUCKET = BeakerWekaBucket(
72 changes: 70 additions & 2 deletions olmoearth_pretrain/nn/flexi_patch_embed.py
@@ -42,6 +42,8 @@ def __init__(
interpolation: str = "bicubic",
antialias: bool = True,
use_linear_patch_embed: bool = True,
patch_embed_hidden_sizes: list[int] | None = None,
post_proj_hidden_sizes: list[int] | None = None,
) -> None:
"""2D image to patch embedding w/ flexible patch sizes.

@@ -59,6 +61,20 @@
antialias: Whether to apply antialiasing resizing
use_linear_patch_embed: If True, use nn.Linear (reshape + matmul via cuBLAS GEMM).
If False, use nn.Conv2d (required to load checkpoints trained before this flag existed).
patch_embed_hidden_sizes: Optional list of hidden layer widths for a
per-pixel MLP applied BEFORE patchification. If None or empty, the
projection is a single nn.Linear over the flattened patch (current
behavior). Otherwise, each pixel's ``in_chans`` channels are first
mapped through
Linear(in_chans, h[0]) -> ReLU -> Linear(h[0], h[1]) -> ReLU -> ... -> Linear(h[-2], h[-1]) -> ReLU
(weights shared across all pixels), producing an ``H x W x h[-1]``
feature map, which is then patchified and projected to
``embedding_size`` with a final Linear(h[-1] * p_h * p_w, embedding_size).
Only supported when use_linear_patch_embed=True.
post_proj_hidden_sizes: Optional list of hidden layer widths for an MLP
applied AFTER the patch projection (``self.proj``). Each entry adds a
ReLU -> Linear(prev, h) layer. Applied before the norm layer.
Only supported when use_linear_patch_embed=True.
"""
super().__init__()

@@ -70,15 +86,57 @@
base_patch_size_at_16 * modality_spec.image_tile_size_factor
)

if patch_embed_hidden_sizes and not use_linear_patch_embed:
raise ValueError(
"patch_embed_hidden_sizes requires use_linear_patch_embed=True"
)
if post_proj_hidden_sizes and not use_linear_patch_embed:
raise ValueError(
"post_proj_hidden_sizes requires use_linear_patch_embed=True"
)

p_h, p_w = self.base_patch_size
self.pixel_proj: nn.Sequential | None = None
if use_linear_patch_embed:
# Reshape patches to (p1 p2 c) then project — hits cuBLAS GEMM (always fast
# on TensorCores) vs Conv2d which hits slow cuDNN paths for small in_chans.
self.proj = nn.Linear(in_chans * p_h * p_w, embedding_size, bias=bias)
if patch_embed_hidden_sizes:
# Per-pixel MLP that maps ``in_chans -> h[0] -> ... -> h[-1]`` with
# ReLU activations (weights shared across every pixel). Applied before
# patchification; the per-patch projection below then maps the flattened
# ``h[-1] * p_h * p_w`` pixel features to ``embedding_size``.
pixel_layers: list[nn.Module] = []
prev = in_chans
for h in patch_embed_hidden_sizes:
pixel_layers.append(nn.Linear(prev, h, bias=bias))
pixel_layers.append(nn.ReLU(inplace=True))
prev = h
self.pixel_proj = nn.Sequential(*pixel_layers)
for m in self.pixel_proj:
if isinstance(m, nn.Linear):
m._skip_custom_init = True
patch_in_features = prev * p_h * p_w
else:
patch_in_features = in_chans * p_h * p_w
self.proj = nn.Linear(patch_in_features, embedding_size, bias=bias)
# Keep PyTorch's default nn.Linear initialization (kaiming_uniform_) for
# patch projection to match prior Conv2d behavior; overriding this with
# encoder-level Xavier init correlated with a PASTIS regression.
self.proj._skip_custom_init = True

# Post-projection MLP: ReLU -> Linear(prev, h) for each h.
self.post_proj: nn.Sequential | None = None
if post_proj_hidden_sizes:
post_layers: list[nn.Module] = []
prev_dim = embedding_size
for h in post_proj_hidden_sizes:
post_layers.append(nn.ReLU(inplace=True))
post_layers.append(nn.Linear(prev_dim, h, bias=bias))
prev_dim = h
self.post_proj = nn.Sequential(*post_layers)
for m in self.post_proj:
if isinstance(m, nn.Linear):
m._skip_custom_init = True
else:
self.proj = nn.Conv2d(
in_chans,
@@ -119,8 +177,18 @@ def _project_linear(
) -> Tensor:
"""Project patches using nn.Linear (reshape → cuBLAS GEMM → reshape)."""
p_h, p_w = self.base_patch_size
x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p_h, p2=p_w)
if self.pixel_proj is not None:
# Per-pixel MLP over channels (weights shared across all pixels):
# [b, c, H, W] -> [b, H, W, c] -> MLP -> [b, H, W, h[-1]]
# Then patchify: [b, h_patches, w_patches, p_h * p_w * h[-1]]
x = rearrange(x, "b c h w -> b h w c")
x = self.pixel_proj(x)
x = rearrange(x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=p_h, p2=p_w)
else:
x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p_h, p2=p_w)
x = self.proj(x)
if self.post_proj is not None:
x = self.post_proj(x)
if has_time_dim:
return rearrange(
x,
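
To make the new projection path concrete, here is a minimal, self-contained sketch of what `_project_linear` computes when both options are enabled. All shapes, hidden widths, and the patch size below are assumptions chosen for illustration, not values used in this PR.

```python
import torch
from torch import nn
from einops import rearrange

b, c, H, W = 2, 12, 32, 32             # batch, in_chans, height, width (assumed)
p_h = p_w = 8                          # base patch size (assumed)
patch_embed_hidden_sizes = [64, 128]   # per-pixel MLP widths (assumed)
post_proj_hidden_sizes = [768]         # post-projection MLP widths (assumed)
embedding_size = 768

# Per-pixel MLP: Linear(12, 64) -> ReLU -> Linear(64, 128) -> ReLU,
# with the same weights applied at every pixel.
pixel_layers: list[nn.Module] = []
prev = c
for h in patch_embed_hidden_sizes:
    pixel_layers += [nn.Linear(prev, h), nn.ReLU(inplace=True)]
    prev = h
pixel_proj = nn.Sequential(*pixel_layers)

# Patch projection over the flattened per-pixel features of each patch.
proj = nn.Linear(prev * p_h * p_w, embedding_size)

# Post-projection MLP: ReLU -> Linear(768, 768).
post_layers: list[nn.Module] = []
prev_dim = embedding_size
for h in post_proj_hidden_sizes:
    post_layers += [nn.ReLU(inplace=True), nn.Linear(prev_dim, h)]
    prev_dim = h
post_proj = nn.Sequential(*post_layers)

x = torch.randn(b, c, H, W)
x = rearrange(x, "b c h w -> b h w c")           # [2, 32, 32, 12]
x = pixel_proj(x)                                # [2, 32, 32, 128]
x = rearrange(
    x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=p_h, p2=p_w
)                                                # [2, 16, 8192]
x = post_proj(proj(x))                           # [2, 16, 768]
print(x.shape)                                   # torch.Size([2, 16, 768])
```

Because the MLP runs per pixel before patchification, its parameter count depends only on the channel count and the hidden widths, not on the patch size; patch-level mixing still happens in the single `proj` Linear, and `post_proj` then refines the resulting token embedding. The new Linear layers are tagged `_skip_custom_init` so they keep PyTorch's default initialization rather than any encoder-level custom init, matching the rationale in the comments above.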
38 changes: 38 additions & 0 deletions olmoearth_pretrain/nn/flexi_vit.py
@@ -193,6 +193,8 @@ def __init__(
band_dropout_rate: float = 0.0,
random_band_dropout: bool = False,
band_dropout_modalities: list[str] | None = None,
patch_embed_hidden_sizes: list[int] | None = None,
post_proj_hidden_sizes: list[int] | None = None,
):
"""Initialize the patch embeddings.

@@ -213,6 +215,19 @@
and acts as stronger augmentation. Default: False (fixed rate).
band_dropout_modalities: If provided, only apply band dropout to these
modalities. If None, apply to all modalities. Default: None.
patch_embed_hidden_sizes: Optional list of hidden layer widths for a
per-pixel MLP applied BEFORE patchification in the spatial
FlexiPatchEmbed. If None or empty, the projection is a single nn.Linear
over the flattened patch (current behavior). Otherwise, each pixel's
channel vector is mapped via an MLP with ReLU activations (weights
shared across all pixels), producing an H x W x h[-1] feature map
that is then patchified and projected to embedding_size. Only applies
to the spatial branch (FlexiPatchEmbed); the non-spatial nn.Linear
branch is unaffected.
post_proj_hidden_sizes: Optional list of hidden layer widths for an MLP
applied AFTER the patch projection. Each entry adds a
ReLU -> Linear(prev, h) layer, applied before the norm. Only applies
to the spatial branch (FlexiPatchEmbed).
"""
super().__init__()
self.max_patch_size = max_patch_size
@@ -223,6 +238,8 @@
self.band_dropout_rate = band_dropout_rate
self.random_band_dropout = random_band_dropout
self.band_dropout_modalities = band_dropout_modalities
self.patch_embed_hidden_sizes = patch_embed_hidden_sizes
self.post_proj_hidden_sizes = post_proj_hidden_sizes
# TODO: want to be able to remove certain bands and modalities
self.per_modality_embeddings = nn.ModuleDict({})

@@ -286,6 +303,8 @@ def _get_patch_embedding_module_for_modality(self, modality: str) -> nn.Module:
base_patch_size_at_16=self.max_patch_size,
modality_spec=modality_spec,
use_linear_patch_embed=self.use_linear_patch_embed,
patch_embed_hidden_sizes=self.patch_embed_hidden_sizes,
post_proj_hidden_sizes=self.post_proj_hidden_sizes,
)
for idx, channel_set_idxs in enumerate(bandset_indices)
}
@@ -1108,6 +1127,8 @@ def __init__(
band_dropout_rate: float = 0.0,
random_band_dropout: bool = False,
band_dropout_modalities: list[str] | None = None,
patch_embed_hidden_sizes: list[int] | None = None,
post_proj_hidden_sizes: list[int] | None = None,
):
"""Initialize the encoder.

@@ -1141,6 +1162,17 @@
random_band_dropout: If True, sample dropout rate from Uniform(0, band_dropout_rate).
band_dropout_modalities: If provided, only apply band dropout to these
modalities. If None, apply to all modalities. Default: None.
patch_embed_hidden_sizes: Optional list of hidden layer widths for a
per-pixel MLP applied BEFORE patchification in the spatial patch
projection. If None or empty, the projection is a single nn.Linear
over the flattened patch (current behavior). Otherwise, each pixel's
``in_chans`` channel vector is mapped via
Linear(in_chans, h[0]) -> ReLU -> ... -> Linear(h[-2], h[-1]) -> ReLU
(weights shared across all pixels), and the resulting H x W x h[-1]
feature map is patchified and projected to embedding_size.
post_proj_hidden_sizes: Optional list of hidden layer widths for an MLP
applied AFTER the patch projection. Each entry adds a
ReLU -> Linear(prev, h) layer, applied before the norm.
"""
self.tokenization_config = tokenization_config or TokenizationConfig()
super().__init__(
@@ -1174,6 +1206,8 @@ def __init__(
self.band_dropout_rate = band_dropout_rate
self.random_band_dropout = random_band_dropout
self.band_dropout_modalities = band_dropout_modalities
self.patch_embed_hidden_sizes = patch_embed_hidden_sizes
self.post_proj_hidden_sizes = post_proj_hidden_sizes
self.patch_embeddings = MultiModalPatchEmbeddings(
self.supported_modality_names,
self.max_patch_size,
@@ -1183,6 +1217,8 @@
band_dropout_rate=0.0,
random_band_dropout=self.random_band_dropout,
band_dropout_modalities=self.band_dropout_modalities,
patch_embed_hidden_sizes=self.patch_embed_hidden_sizes,
post_proj_hidden_sizes=self.post_proj_hidden_sizes,
)
self.output_embedding_size = output_embedding_size
# If output_embedding_size is set, project tokens to that size after attention
@@ -2082,6 +2118,8 @@ class EncoderConfig(Config):
band_dropout_rate: float = 0.0
random_band_dropout: bool = False
band_dropout_modalities: list[str] | None = None
patch_embed_hidden_sizes: list[int] | None = None
post_proj_hidden_sizes: list[int] | None = None

def __post_init__(self) -> None:
"""Coerce raw dicts to TokenizationConfig for old checkpoint compatibility."""