60 commits
9065da9
Add script courtesy of @yawenzzzz
gabrieltseng Mar 30, 2026
2069b72
Add more evals
gabrieltseng Mar 30, 2026
4c9cae4
less random more time
gabrieltseng Mar 30, 2026
b664309
less time more random
gabrieltseng Mar 30, 2026
06689bd
hack
gabrieltseng Mar 31, 2026
e499f89
ctrl + z
gabrieltseng Mar 31, 2026
6318d6e
random only
gabrieltseng Mar 31, 2026
eeae879
back to normal
gabrieltseng Apr 1, 2026
87df9bf
:facepalm:
gabrieltseng Apr 1, 2026
09a130d
be consistent
gabrieltseng Apr 1, 2026
4375d6e
same fix
gabrieltseng Apr 1, 2026
0177de4
Update script
gabrieltseng Apr 2, 2026
4bdaeb5
oil spill eval is massive leading to OOM - lets skip it
gabrieltseng Apr 3, 2026
186ad56
Add comment explaining
gabrieltseng Apr 3, 2026
e4190e5
Add single bandset models to flop calculations
gabrieltseng Apr 7, 2026
23a1a8e
Add loss ablation
gabrieltseng Apr 13, 2026
3d8ee10
Add ablation
gabrieltseng Apr 13, 2026
3e2d8fd
rename
gabrieltseng Apr 13, 2026
7c81471
l1/cosine similarity single band set
favyen2 Apr 14, 2026
7dd15c8
use the right base script
favyen2 Apr 14, 2026
5846f09
fix
favyen2 Apr 14, 2026
cafc95a
update script defaults, add base sweeps
gabrieltseng Apr 21, 2026
8ef6d20
Add lr 0.0003
gabrieltseng Apr 21, 2026
c9ff50b
run on 8 gpus, dont repeat the original run
gabrieltseng Apr 21, 2026
04dc50b
nano and tiny runs
gabrieltseng Apr 22, 2026
28c6d25
Merge branch 'main' into gabi/masking-ablations
gabrieltseng Apr 22, 2026
6251859
update to vectorized loss
gabrieltseng Apr 22, 2026
6e4b591
update to vec loss
gabrieltseng Apr 22, 2026
7d11df1
update run names
gabrieltseng Apr 22, 2026
1dee7a2
fix
gabrieltseng Apr 22, 2026
b759b25
Merge branch 'gabi/masking-ablations' of github.com:allenai/helios in…
favyen2 Apr 22, 2026
d42fe6b
add loss ablation
gabrieltseng Apr 22, 2026
34b22d2
add ndvi
gabrieltseng Apr 22, 2026
7221d29
add 0.001 lrs for nano and tiny
gabrieltseng Apr 24, 2026
ede4758
more base, lower lr
gabrieltseng Apr 24, 2026
8f1b403
update st_model.py to be compatible with latest changes
favyen2 Apr 24, 2026
c1bd50a
fix
favyen2 Apr 24, 2026
7597152
Merge branch 'gabi/masking-ablations' into favyen/20260424-stmodel
favyen2 Apr 24, 2026
8b5e8ea
Merge branch 'favyen/20260414-misc-single-bandset' into favyen/202604…
favyen2 Apr 24, 2026
feaadbc
add cosine similarity st model script
favyen2 Apr 24, 2026
d8837d2
handle fast_pass sort of but not really
uakfdotb Apr 25, 2026
08f4d21
use vectorized loss
gabrieltseng Apr 25, 2026
54e936b
:facepalm:
gabrieltseng Apr 25, 2026
a7bda6f
lower beta2
gabrieltseng Apr 25, 2026
92014a1
add windowed attention size 3 scripts
favyen2 Apr 27, 2026
041e4ec
add script for 3x3 cosine similarity windowed attention with full dec…
favyen2 Apr 27, 2026
11545ec
fix
favyen2 Apr 27, 2026
d31133c
make the run really long to avoid triggering finished run, like in fi…
gabrieltseng Apr 30, 2026
6f1ef4e
Merge remote-tracking branch 'origin/main' into favyen/20260424-stmodel
favyen2 Apr 30, 2026
ed34fc0
get best run over seed
gabrieltseng May 1, 2026
a8b3a8f
tmp
gabrieltseng May 1, 2026
0dd3007
knn and lp script
gabrieltseng May 1, 2026
1bae61d
update cluster
gabrieltseng May 1, 2026
72eff63
high
gabrieltseng May 1, 2026
e1a2345
no need to rerun
gabrieltseng May 1, 2026
bcf1f82
plus warning
gabrieltseng May 1, 2026
9e158bc
Fix band dropout being applied during fine-tuning bug that Gabi ident…
favyen2 May 1, 2026
84bb24b
Merge branch 'favyen/20260501-band-dropout-fix' into favyen/20260424-…
favyen2 May 1, 2026
7c2dfb6
Merge remote-tracking branch 'origin/gabi/masking-ablations' into fav…
favyen2 May 1, 2026
bd09d59
apply band dropout fix for st model
favyen2 May 1, 2026
10 changes: 10 additions & 0 deletions olmoearth_pretrain/evals/finetune/train.py
@@ -158,6 +158,16 @@ def run_finetune_eval(
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

if hasattr(model, "disable_band_dropout"):
prev_rate = getattr(
getattr(model, "patch_embeddings", None), "band_dropout_rate", 0.0
)
if prev_rate > 0.0:
logger.warning(
f"Overriding band_dropout_rate from {prev_rate} to 0.0 for finetuning."
)
model.disable_band_dropout()

ft = BackboneWithHead(
model=model,
task_type=task_config.task_type,
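For context, a minimal sketch of the guard's effect when a backbone arrives with a non-zero configured rate. The toy classes and the 0.15 rate below are made-up stand-ins for the real model classes in `olmoearth_pretrain.nn`:

```python
import logging

logger = logging.getLogger(__name__)


class ToyPatchEmbeddings:
    band_dropout_rate = 0.15  # hypothetical leftover rate from pretraining


class ToyBackbone:
    def __init__(self) -> None:
        self.patch_embeddings = ToyPatchEmbeddings()

    def disable_band_dropout(self) -> None:
        self.patch_embeddings.band_dropout_rate = 0.0


model = ToyBackbone()

# Same guard as above: warn about, then zero out, any configured dropout rate
# so fine-tuning and evaluation never drop spectral bands.
if hasattr(model, "disable_band_dropout"):
    prev_rate = getattr(
        getattr(model, "patch_embeddings", None), "band_dropout_rate", 0.0
    )
    if prev_rate > 0.0:
        logger.warning(
            f"Overriding band_dropout_rate from {prev_rate} to 0.0 for finetuning."
        )
    model.disable_band_dropout()

assert model.patch_embeddings.band_dropout_rate == 0.0
```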
32 changes: 17 additions & 15 deletions olmoearth_pretrain/internal/all_evals.py
@@ -181,7 +181,7 @@ def load_user_module(path: str) -> Any:
embedding_batch_size=32,
probe_batch_size=8,
num_workers=2,
pooling_type=PoolingType.MAX,
pooling_type=PoolingType.MEAN,
norm_stats_from_pretrained=True,
probe_lr=0.1,
eval_interval=Duration.epochs(50),
@@ -385,20 +385,22 @@ def load_user_module(path: str) -> Any:
epochs=50,
eval_mode=EvalMode.LINEAR_PROBE,
),
"oil_spill_detection": DownstreamTaskConfig(
dataset="oil_spill_detection",
embedding_batch_size=128,
probe_batch_size=8,
num_workers=8,
pooling_type=PoolingType.MEAN,
norm_stats_from_pretrained=True,
norm_method=NormMethod.NORM_NO_CLIP_2_STD,
probe_lr=0.01,
eval_interval=Duration.epochs(10),
input_modalities=[Modality.SENTINEL1.name],
epochs=50,
eval_mode=EvalMode.LINEAR_PROBE,
),
# this eval is very large and can lead to
# OOM errors. Skipping for now.
# "oil_spill_detection": DownstreamTaskConfig(
# dataset="oil_spill_detection",
# embedding_batch_size=128,
# probe_batch_size=8,
# num_workers=8,
# pooling_type=PoolingType.MEAN,
# norm_stats_from_pretrained=True,
# norm_method=NormMethod.NORM_NO_CLIP_2_STD,
# probe_lr=0.01,
# eval_interval=Duration.epochs(10),
# input_modalities=[Modality.SENTINEL1.name],
# epochs=50,
# eval_mode=EvalMode.LINEAR_PROBE,
# ),
}

EMBED_DIAG_TASKS = {
7 changes: 7 additions & 0 deletions olmoearth_pretrain/internal/full_eval_sweep.py
@@ -515,6 +515,10 @@ def _get_pooling_type_str(pooling_type: str) -> str:


LAUNCH_OVERRIDES = "--launch.priority=high --launch.num_gpus=1 --launch.task_name=eval"
# Overwrite the max duration to enable eval of the last step of the checkpoint
MAX_DURATION_OVERRIDE = (
"--trainer.max_duration.value=10000000 --trainer.max_duration.unit=steps"
)


def _get_env_prefix(args: argparse.Namespace, module_path: str) -> str:
@@ -989,6 +993,8 @@ def build_commands(args: argparse.Namespace, extra_cli: list[str]) -> list[str]:
commands_to_run_new.append(cmd)
commands_to_run = commands_to_run_new

commands_to_run = [f"{cmd} {MAX_DURATION_OVERRIDE}" for cmd in commands_to_run]

# Filter out skipped tasks if task-skip-names is provided
if args.task_skip_names:
skip_names = [name.strip() for name in args.task_skip_names.split(",")]
@@ -1133,6 +1139,7 @@ def main() -> None:
args, extra_cli = parser.parse_known_args()

commands_to_run = build_commands(args, extra_cli)
commands_to_run = commands_to_run[11:]

logger.info(f"Running {len(commands_to_run)} commands")
for cmd in commands_to_run:
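A quick illustration of how the override lands on each launch command; the command string below is a made-up placeholder, only the flags come from this file:

```python
MAX_DURATION_OVERRIDE = (
    "--trainer.max_duration.value=10000000 --trainer.max_duration.unit=steps"
)

# Hypothetical command; the real strings are assembled by build_commands().
commands_to_run = [
    "python -m some_eval_entrypoint --task=crop_type --launch.priority=high --launch.num_gpus=1"
]
commands_to_run = [f"{cmd} {MAX_DURATION_OVERRIDE}" for cmd in commands_to_run]

# Every command now ends with the max-duration flags, so evaluating the final
# checkpoint step no longer hits the trainer's "run already finished" condition.
print(commands_to_run[0])
```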
15 changes: 11 additions & 4 deletions olmoearth_pretrain/nn/flexi_vit.py
@@ -1168,6 +1168,9 @@ def __init__(
self.max_patch_size = max_patch_size
self.embedding_size = embedding_size
self.use_linear_patch_embed = use_linear_patch_embed
# Configured rate; remains inactive until ``enable_band_dropout`` is called.
# Default is disabled so fine-tuning never applies band dropout unless the
# caller (e.g. pretraining online encoder) explicitly enables it.
self.band_dropout_rate = band_dropout_rate
self.random_band_dropout = random_band_dropout
self.band_dropout_modalities = band_dropout_modalities
@@ -1177,7 +1180,7 @@ def __init__(
self.embedding_size,
tokenization_config=self.tokenization_config,
use_linear_patch_embed=self.use_linear_patch_embed,
band_dropout_rate=self.band_dropout_rate,
band_dropout_rate=0.0,
random_band_dropout=self.random_band_dropout,
band_dropout_modalities=self.band_dropout_modalities,
)
@@ -1209,9 +1212,13 @@ def __init__(
if self.has_register_tokens:
self._init_register_tokens()

def disable_band_dropout(self) -> None:
"""Disable band dropout (e.g. for target/EMA encoder)."""
self.patch_embeddings.band_dropout_rate = 0.0
def enable_band_dropout(self) -> None:
"""Enable band dropout using the configured rate.

Band dropout is disabled by default so it never activates during
fine-tuning. Call this only on the online encoder during pretraining.
"""
self.patch_embeddings.band_dropout_rate = self.band_dropout_rate

def _init_register_tokens(self) -> None:
"""Initialize the register tokens."""
9 changes: 6 additions & 3 deletions olmoearth_pretrain/nn/latent_mim.py
@@ -47,9 +47,12 @@ def __init__(
self.target_encoder = deepcopy(self.encoder)
for p in self.target_encoder.parameters():
p.requires_grad = False
# Disable band dropout on target encoder so it always sees full spectral info.
if hasattr(self.target_encoder, "disable_band_dropout"):
self.target_encoder.disable_band_dropout()
# Band dropout is off by default so it never activates during fine-tuning;
# turn it on for the online encoder only during pretraining. The target
# encoder (deepcopy made above) keeps the disabled state and always sees
# full spectral info.
if hasattr(self.encoder, "enable_band_dropout"):
self.encoder.enable_band_dropout()

def forward(
self, x: MaskedOlmoEarthSample, patch_size: int
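Together with the flexi_vit.py change above, the new contract is that band dropout is configured at construction time but stays inert until `enable_band_dropout()` is called. A minimal sketch of the intended lifecycle; the classes below are simplified stand-ins for the real encoders and the 0.15 rate is a made-up value:

```python
from copy import deepcopy


class _PatchEmbeddings:
    def __init__(self) -> None:
        self.band_dropout_rate = 0.0  # always starts disabled


class _Encoder:
    """Stand-in for the flexi_vit / st_model encoders after this change."""

    def __init__(self, band_dropout_rate: float = 0.0) -> None:
        self.band_dropout_rate = band_dropout_rate  # configured rate, inert for now
        self.patch_embeddings = _PatchEmbeddings()

    def enable_band_dropout(self) -> None:
        self.patch_embeddings.band_dropout_rate = self.band_dropout_rate


encoder = _Encoder(band_dropout_rate=0.15)
assert encoder.patch_embeddings.band_dropout_rate == 0.0  # inactive by default

# LatentMIM pretraining: deep-copy the target/EMA encoder first, then enable
# dropout on the online encoder only, so the target always sees the full spectrum.
target_encoder = deepcopy(encoder)
encoder.enable_band_dropout()
assert encoder.patch_embeddings.band_dropout_rate == 0.15
assert target_encoder.patch_embeddings.band_dropout_rate == 0.0

# Fine-tuning and evaluation never call enable_band_dropout(), so checkpoints
# pretrained with band dropout are unaffected downstream.
```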
78 changes: 73 additions & 5 deletions olmoearth_pretrain/nn/st_model.py
@@ -751,6 +751,10 @@ def __init__(
fuse_using_cross_attn: bool = True,
tokenization_config: TokenizationConfig | None = None,
use_linear_patch_embed: bool = True,
band_dropout_rate: float = 0.0,
random_band_dropout: bool = False,
band_dropout_modalities: list[str] | None = None,
output_embedding_size: int | None = None,
):
"""Initialize the encoder.

@@ -779,6 +783,10 @@ def __init__(
tokenization_config: Optional config for custom band groupings
use_linear_patch_embed: If True, use nn.Linear for patch projection (faster).
Set False to load checkpoints trained before this flag existed (Conv2d weights).
band_dropout_rate: Probability of dropping each band channel during training.
random_band_dropout: If True, sample dropout rate from Uniform(0, band_dropout_rate).
band_dropout_modalities: If provided, only apply band dropout to these modalities.
output_embedding_size: If set, project tokens to this size after attention.
"""
self.tokenization_config = tokenization_config or TokenizationConfig()
super().__init__(
@@ -801,16 +809,38 @@ def __init__(
self.fuse_layers = fuse_layers
self.layer_attention_modes = layer_attention_modes
self.fuse_using_cross_attn = fuse_using_cross_attn
# Configured rate; remains inactive until ``enable_band_dropout`` is called.
# Default is disabled so fine-tuning never applies band dropout unless the
# caller (e.g. pretraining online encoder) explicitly enables it.
self.band_dropout_rate = band_dropout_rate
self.random_band_dropout = random_band_dropout
self.band_dropout_modalities = band_dropout_modalities
self.patch_embeddings = MultiModalPatchEmbeddings(
self.supported_modality_names,
self.max_patch_size,
self.embedding_size,
tokenization_config=self.tokenization_config,
use_linear_patch_embed=use_linear_patch_embed,
band_dropout_rate=0.0,
random_band_dropout=self.random_band_dropout,
band_dropout_modalities=self.band_dropout_modalities,
)
self.output_embedding_size = output_embedding_size
# If output_embedding_size is set, project tokens to that size after attention
self.embedding_projector: ProjectAndAggregate | None = None
if output_embedding_size is not None:
self.embedding_projector = ProjectAndAggregate(
embedding_size=self.embedding_size,
num_layers=1,
output_embedding_size=output_embedding_size,
only_project=True,
)
final_embedding_size = output_embedding_size
else:
final_embedding_size = self.embedding_size
# TODO: add backwards compatibility without the project and aggregate module
self.project_and_aggregate = ProjectAndAggregate(
embedding_size=self.embedding_size,
embedding_size=final_embedding_size,
num_layers=num_projection_layers,
aggregate_then_project=aggregate_then_project,
)
@@ -1111,24 +1141,37 @@ def apply_attn(

return x

def enable_band_dropout(self) -> None:
"""Enable band dropout using the configured rate.

Band dropout is disabled by default so it never activates during
fine-tuning. Call this only on the online encoder during pretraining.
"""
self.patch_embeddings.band_dropout_rate = self.band_dropout_rate

def forward(
self,
x: MaskedOlmoEarthSample,
patch_size: int,
input_res: int = BASE_GSD,
token_exit_cfg: dict | None = None,
) -> tuple[TokensAndMasks, Tensor]:
fast_pass: bool = False,
) -> dict[str, Any]:
"""Process masked input samples into token representations.

Args:
x: Masked input sample containing the data to be encoded
patch_size: Size of patches to divide the input into
input_res: Resolution of the input data
token_exit_cfg: Configuration for token exit
fast_pass: If True, always pass None as the attention mask to the transformer; this enables torch flash attention and skips mask construction and sorting.

Returns:
TokensAndMasks containing the encoded representations and their masks
Dict with a 'tokens_and_masks' key and, when fast_pass is False, a 'project_aggregated' key.
"""
if fast_pass and token_exit_cfg is not None:
raise ValueError("token_exit_cfg cannot be set when fast_pass is True")

# TODO: Add step to validate the exit config is valid
patchified_tokens_and_masks = self.patch_embeddings.forward(x, patch_size)
if token_exit_cfg is None or any(
@@ -1142,7 +1185,19 @@ def forward(
token_exit_cfg=token_exit_cfg,
)
output = TokensAndMasks(**patchified_tokens_and_masks)
return output, self.project_and_aggregate(output)

# Project to output_embedding_size if configured
if self.embedding_projector is not None:
output = self.embedding_projector(output)

output_dict: dict[str, Any] = {
"tokens_and_masks": output,
}

if not fast_pass:
output_dict["project_aggregated"] = self.project_and_aggregate(output)

return output_dict

def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
"""Apply FSDP to the model."""
@@ -1444,7 +1499,7 @@ def forward(
Returns:
TokensAndMasks containing the predicted tokens and their masks
"""
decoder_emedded_dict = x.as_dict(include_nones=True)
decoder_emedded_dict = x.as_dict()
# Apply Input Norms and encoder to decoder embeds to each modality
available_modalities = x.modalities
modalities_to_process = get_modalities_to_process(
@@ -1519,6 +1574,10 @@ class STEncoderConfig(Config):
fuse_using_cross_attn: bool = True
tokenization_config: TokenizationConfig | None = None
use_linear_patch_embed: bool = True
output_embedding_size: int | None = None
band_dropout_rate: float = 0.0
random_band_dropout: bool = False
band_dropout_modalities: list[str] | None = None

def __post_init__(self) -> None:
"""Coerce raw dicts to TokenizationConfig for old checkpoint compatibility."""
@@ -1533,6 +1592,15 @@ def validate(self) -> None:
for modality in self.supported_modalities:
if modality not in Modality.values():
raise ValueError(f"Modality {modality} is not supported")
if self.band_dropout_modalities is not None:
unknown = set(self.band_dropout_modalities) - set(
self.supported_modality_names
)
if unknown:
raise ValueError(
f"band_dropout_modalities contains modalities not in "
f"supported_modality_names: {unknown}"
)
if self.tokenization_config is not None:
self.tokenization_config.validate()

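Since forward now returns a dict instead of a (TokensAndMasks, Tensor) tuple, call sites need a small update. A minimal usage sketch; the encoder and sample construction are omitted and the patch size is arbitrary:

```python
# Sketch only: `encoder` is an ST encoder instance and `sample` a
# MaskedOlmoEarthSample, both assumed to be constructed elsewhere.
out = encoder(sample, patch_size=8)
tokens_and_masks = out["tokens_and_masks"]  # TokensAndMasks (projected if output_embedding_size is set)
pooled = out["project_aggregated"]          # aggregated embedding from ProjectAndAggregate

# fast_pass=True passes None as the attention mask (torch flash attention path),
# skips mask construction and sorting, is incompatible with token_exit_cfg, and
# omits the aggregated output.
fast_out = encoder(sample, patch_size=8, fast_pass=True)
assert "project_aggregated" not in fast_out
```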
26 changes: 26 additions & 0 deletions olmoearth_pretrain/train/loss.py
@@ -991,6 +991,32 @@ def compute(
return F.l1_loss(pred, target)


@LOSS_REGISTRY.register("cosine_similarity")
class CosineSimilarityLoss(Loss):
"""Negative mean cosine similarity between predicted and target decoder tokens."""

name = "CosineSim"

def compute(
self, predictions: TokensAndMasks, targets: TokensAndMasks, **kwargs: Any
) -> Tensor:
"""Compute negative cosine similarity loss between predictions and targets.

Args:
predictions: Model predictions.
targets: Ground truth targets.
**kwargs: Additional keyword arguments.

Returns:
The computed loss value (negative mean cosine similarity).
"""
all_preds, all_masks = predictions.flatten_all_tokens_and_masks()
all_targets = targets.flatten_all_tokens_and_masks()[0]
pred = all_preds[all_masks == MaskValue.DECODER.value]
target = all_targets[all_masks == MaskValue.DECODER.value]
return -F.cosine_similarity(pred, target, dim=-1).mean()


@LOSS_REGISTRY.register("l2")
class L2Loss(Loss):
"""Loss function for L2 (mean squared error)."""
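For intuition, the same computation on plain tensors; the shapes below are arbitrary, while in the loss above the inputs are the decoder-masked token embeddings:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(16, 128)    # (num_decoder_tokens, embedding_dim), made-up shape
target = torch.randn(16, 128)

# Per-token cosine similarity lies in [-1, 1]; negating its mean gives a loss
# that is minimized at -1 when predicted and target directions match exactly
# and, unlike the neighboring L1/L2 losses, ignores vector magnitude.
loss = -F.cosine_similarity(pred, target, dim=-1).mean()
print(loss.item())
```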