diff --git a/--config-path=configs b/--config-path=configs new file mode 100644 index 00000000..e69de29b diff --git a/.gitignore b/.gitignore index 95283c5b..27481beb 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,10 @@ exploratory/data outputs lightning_logs datasets +slurm-*.out +slurm-*.stats +logs/ +wandb/ slurm_scripts logs local_hydra/local_experiment/*.yaml diff --git a/EOF b/EOF new file mode 100644 index 00000000..e69de29b diff --git a/Qien_Code/data_debug.ipynb b/Qien_Code/data_debug.ipynb new file mode 100644 index 00000000..e69de29b diff --git a/Qien_Code/download_raw_osisaf.py b/Qien_Code/download_raw_osisaf.py new file mode 100644 index 00000000..2fd7b2c5 --- /dev/null +++ b/Qien_Code/download_raw_osisaf.py @@ -0,0 +1,16 @@ +# save_osisaf_raw.py +import os +import xarray as xr + +OPENDAP_URL = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" +RAW_DIR = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/osisaf_raw" + +os.makedirs(RAW_DIR, exist_ok=True) + +print("Opening OPENDAP (raw)...") +ds = xr.open_dataset(OPENDAP_URL) + +out_path = os.path.join(RAW_DIR, "osisaf_nh_raw.nc") +ds.to_netcdf(out_path) + +print(f"Saved raw dataset → {out_path}") diff --git a/Qien_Code/eval_epd_masked.sh b/Qien_Code/eval_epd_masked.sh new file mode 100644 index 00000000..58ea0d9d --- /dev/null +++ b/Qien_Code/eval_epd_masked.sh @@ -0,0 +1,33 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 1:00:00 +#SBATCH --nodes 1 +#SBATCH --gpus 1 +#SBATCH --job-name eval_masked_epd + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/.conda/envs/autocast/bin/python -m 
autocast.scripts.eval.encoder_processor_decoder \ + --config-dir outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29 \ + --config-name resolved_config \ + hydra.run.dir=outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29/eval_run \ + datamodule.data_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf_selectedyears \ + eval.checkpoint=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29/encoder_processor_decoder.ckpt \ + eval.free_running_only=true \ + eval.batch_indices=[0] \ + eval.video_dir=outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29/eval_videos \ + eval.video_format=mp4 \ + eval.fps=5 \ + eval.device=cuda \ + +model.processor.mask_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/land_mask.pt diff --git a/Qien_Code/get_osisaf_data.py b/Qien_Code/get_osisaf_data.py new file mode 100644 index 00000000..41751322 --- /dev/null +++ b/Qien_Code/get_osisaf_data.py @@ -0,0 +1,198 @@ +import os +import numpy as np +import xarray as xr +import pandas as pd +import torch + +# ----------------------- +# Config +# ----------------------- +OPENDAP_URL = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" +OUT_DIR = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all" # change me + +# Split rule (year-based). Adjust as you like. +# By default: last 5 years = test, previous 5 year = valid, rest = train. 
+N_TEST_YEARS = 5 +N_VALID_YEARS = 5 + +# Keep tensor size manageable +DTYPE = np.float32 # you can switch to np.float16 to halve disk usage +CHUNK_DAYS = 30 # Dask chunking along time + +# ----------------------- +# Helpers +# ----------------------- +def infer_sic_var(ds: xr.Dataset) -> str: + """Try common variable names; fallback to first data_var.""" + candidates = ["ice_conc", "conc", "sic", "ice_concentration"] + for v in candidates: + if v in ds.data_vars: + return v + # fallback: pick first non-empty variable + for v in ds.data_vars: + return v + raise ValueError("No data variables found in dataset.") + +def drop_feb29(da: xr.DataArray) -> xr.DataArray: + """Drop Feb 29 to ensure 365 days per year.""" + t = da["time"].to_index() + mask = ~((t.month == 2) & (t.day == 29)) + return da.isel(time=np.where(mask)[0]) + +# def normalize_units(sic: xr.DataArray) -> xr.DataArray: +# """Convert to [0,1] if necessary.""" +# # Heuristic: if values look like percent (0..100), convert. +# # Use robust quantile to avoid occasional fill values. +# q99 = float(sic.quantile(0.99, skipna=True).compute()) +# if q99 > 1.5: +# sic = sic / 100.0 +# return sic + +def normalize_units(sic): + # Try to infer scale from attributes/units + units = (sic.attrs.get("units", "") or "").lower() + # Many SIC products are in %, some are fraction. + if "%" in units or "percent" in units: + sic = sic / 100.0 + + return sic + +def clean_fill(sic: xr.DataArray) -> xr.DataArray: + """Handle common fill conventions & clip to [0,1].""" + # Convert to float for NaN support + sic = sic.astype("float32") + + # If dataset provides a _FillValue, xarray usually decodes it already, + # but just in case: + fill = sic.attrs.get("_FillValue", None) + if fill is not None: + sic = sic.where(sic != fill) + + # Some OSI SAF products can include impossible values; keep plausible range. 
+ sic = sic.where((sic >= 0.0) & (sic <= 1.0)) + + return sic + +def ensure_order(sic: xr.DataArray) -> xr.DataArray: + """Ensure (time, y, x) ordering (or time, yc, xc etc.).""" + # Try common spatial dim names + spatial_dims = [d for d in sic.dims if d != "time"] + if len(spatial_dims) != 2: + raise ValueError(f"Expected 2 spatial dims besides time, got dims={sic.dims}") + return sic.transpose("time", spatial_dims[0], spatial_dims[1]) + +def year_list(da: xr.DataArray): + years = np.unique(da["time"].dt.year.values) + years = years[~np.isnan(years)].astype(int) + years.sort() + return years.tolist() + +# ----------------------- +# Main +# ----------------------- +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + for split in ["train", "valid", "test"]: + os.makedirs(os.path.join(OUT_DIR, split), exist_ok=True) + + print(f"Opening OPENDAP: {OPENDAP_URL}") + ds = xr.open_dataset(OPENDAP_URL, chunks={"time": CHUNK_DAYS}) + + sic_var = infer_sic_var(ds) + print(f"Using SIC variable: {sic_var}") + sic = ds[sic_var] + + # Basic cleaning pipeline (same spirit as before) + sic = ensure_order(sic) + sic = drop_feb29(sic) + sic = normalize_units(sic) + sic = clean_fill(sic) + + # Confirm spatial shape + T, H, W = sic.sizes["time"], sic.sizes[sic.dims[1]], sic.sizes[sic.dims[2]] + print(f"Full time length after dropping Feb29: T={T}") + print(f"Spatial: H={H}, W={W}") + + years = year_list(sic) + print(f"Years available: {years[0]} .. 
{years[-1]} (n={len(years)})") + + # Split by year (last years for valid/test by default) + test_years = years[-N_TEST_YEARS:] + valid_years = years[-(N_TEST_YEARS + N_VALID_YEARS):-N_TEST_YEARS] if N_VALID_YEARS > 0 else [] + train_years = [y for y in years if (y not in valid_years and y not in test_years)] + + print(f"Split years:") + print(f" train: {train_years[0]}..{train_years[-1]} (n={len(train_years)})") + if valid_years: + print(f" valid: {valid_years} (n={len(valid_years)})") + print(f" test : {test_years} (n={len(test_years)})") + + def build_and_save(years_subset, split_name): + # Build yearly trajectories: (traj, time=365, H, W, C=1) + traj_list = [] + for y in years_subset: + # Select this year + one = sic.sel(time=str(y)) + + # Reindex to full daily sequence (after Feb29 drop) to detect missing days + t0 = pd.Timestamp(f"{y}-01-01") + t1 = pd.Timestamp(f"{y}-12-31") + full = pd.date_range(t0, t1, freq="D") + full = full[~((full.month == 2) & (full.day == 29))] # drop Feb29 + one = one.reindex(time=full) + + # Must be 365 days + if one.sizes["time"] != 365: + print(f"[WARN] Year {y}: time={one.sizes['time']} != 365, skipping") + continue + + # Compute year into memory (365*432*432 ~ 68M floats -> ~270MB float32) + arr = one.data + if hasattr(arr, "compute"): + arr = arr.compute() + arr = np.asarray(arr, dtype=DTYPE) + + # Add channel dim + arr = arr[..., None] # (365, H, W, 1) + + # This is wrong I think! : replace NaNs with 0, and keep a mask if you want later + # For "same as before", we usually did NaN->0 and rely on implicit land mask. 
+ # arr = np.nan_to_num(arr, nan=0.0) + + traj_list.append(arr) + print(f" built year {y}: {arr.shape} {arr.dtype}") + + if not traj_list: + raise RuntimeError(f"No usable trajectories for split={split_name}") + + data = np.stack(traj_list, axis=0) # (traj, 365, H, W, 1) + print(f"[{split_name}] stacked: {data.shape} dtype={data.dtype}") + + # Save as torch tensor + tensor = torch.from_numpy(data) + out_path = os.path.join(OUT_DIR, split_name, "data.pt") + torch.save(tensor, out_path) + print(f"[{split_name}] saved: {out_path}") + + build_and_save(train_years, "train") + if valid_years: + build_and_save(valid_years, "valid") + build_and_save(test_years, "test") + + print("Done.") + +if __name__ == "__main__": + RAW_DIR = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw" + RAW_FILE = os.path.join(RAW_DIR, "osisaf_nh_sic_reprocessed.nc") + + # Download if missing + if not os.path.exists(RAW_FILE): + os.makedirs(RAW_DIR, exist_ok=True) + print(f"Downloading to {RAW_FILE}...") + # Download logic here (cURL, xarray, etc.) 
+ ds = xr.open_dataset(OPENDAP_URL) + ds.to_netcdf(RAW_FILE) + else: + print(f"Using cached: {RAW_FILE}") + ds = xr.open_dataset(RAW_FILE, chunks={"time": CHUNK_DAYS}) + main() diff --git a/Qien_Code/run_osisaf_download.sh b/Qien_Code/run_osisaf_download.sh new file mode 100644 index 00000000..4148821d --- /dev/null +++ b/Qien_Code/run_osisaf_download.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 2:00:00 # 2 hours (adjust if needed) +#SBATCH --nodes 1 +#SBATCH --gpus 0 # No GPU needed for this +#SBATCH --tasks-per-node 8 # 8 CPUs for parallel ops +#SBATCH --job-name process_osisaf + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + + +# Activate your virtual environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/Qien_Code + +python get_osisaf_data.py \ No newline at end of file diff --git a/Qien_Code/train_epd_osisaf_masked.sh b/Qien_Code/train_epd_osisaf_masked.sh new file mode 100644 index 00000000..11a7957c --- /dev/null +++ b/Qien_Code/train_epd_osisaf_masked.sh @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 4:00:00 +#SBATCH --nodes 1 +#SBATCH --gpus 1 +#SBATCH --mem=64G +#SBATCH --job-name epd_masked_osisaf + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +# Using selected years data (2014-2020: 5 years train, 1 year valid, 1 year test) +# Previous: experiment_name=seaice/epd_flow_pixels_in2_out1_masked__2018_data +# Previous: 
datamodule.data_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf_2018/osisaf_nh_sic_2018 + +/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/.conda/envs/autocast/bin/python -m autocast.scripts.train.encoder_processor_decoder \ + experiment_name=seaice/epd_flow_pixels_in2_out1_masked__selectedyears \ + datamodule=osisaf_nh_sic \ + datamodule.data_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf_selectedyears \ + trainer.max_epochs=40 \ + trainer.accelerator=gpu \ + trainer.devices=1 \ + logging.wandb.enabled=true \ + encoder@model.encoder=identity \ + decoder@model.decoder=identity \ + processor@model.processor=masked_flow_matching \ + model.processor.backbone.global_cond_channels=null \ + model.processor.backbone.include_global_cond=false \ + model.processor.mask_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/land_mask.pt diff --git a/compare_processed_datasets.sh b/compare_processed_datasets.sh new file mode 100644 index 00000000..c494c348 --- /dev/null +++ b/compare_processed_datasets.sh @@ -0,0 +1,115 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:30:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=32G +#SBATCH --job-name compare_processed +#SBATCH --output=logs/compare_processed_%j.out +#SBATCH --error=logs/compare_processed_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +# Activate conda environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import os + +print("=" * 100) +print("COMPARING TWO PROCESSED DATASETS") +print("=" * 100) + +# Path 1: Full dataset (multi-year, train/valid/test split) +path1_base = 
"/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all" +# Path 2: 2018 only dataset +path2_base = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic" + +print("\n" + "=" * 100) +print("DATASET 1: osisaf_nh_sic_all (FULL MULTI-YEAR DATA)") +print("=" * 100) +print(f"Location: {path1_base}") +print(f"Subdirectories: {os.listdir(path1_base)}") + +for split in ["train", "valid", "test"]: + split_path = os.path.join(path1_base, split, "data.pt") + if os.path.exists(split_path): + data = torch.load(split_path, map_location='cpu') + file_size_mb = os.path.getsize(split_path) / (1024**2) + + print(f"\n [{split}]") + print(f" File size: {file_size_mb:.2f} MB") + print(f" Type: {type(data)}") + + if isinstance(data, dict): + print(f" Keys: {list(data.keys())}") + for key, val in data.items(): + if isinstance(val, torch.Tensor): + print(f" {key}: shape {val.shape}, dtype {val.dtype}") + print(f" Value range: [{val.min():.6f}, {val.max():.6f}], mean: {val.mean():.6f}") + elif isinstance(val, list): + print(f" {key}: list of {len(val)} items") + elif isinstance(data, torch.Tensor): + print(f" Shape: {data.shape}, dtype {data.dtype}") + print(f" Value range: [{data.min():.6f}, {data.max():.6f}], mean: {data.mean():.6f}") + +print("\n\n" + "=" * 100) +print("DATASET 2: osisaf_nh_sic (2018 ONLY DATA)") +print("=" * 100) +print(f"Location: {path2_base}") +print(f"Subdirectories: {os.listdir(path2_base)}") + +for split in ["train", "valid", "test"]: + split_path = os.path.join(path2_base, split, "data.pt") + if os.path.exists(split_path): + data = torch.load(split_path, map_location='cpu') + file_size_mb = os.path.getsize(split_path) / (1024**2) + + print(f"\n [{split}]") + print(f" File size: {file_size_mb:.2f} MB") + print(f" Type: {type(data)}") + + if isinstance(data, dict): + print(f" Keys: {list(data.keys())}") + for key, val in data.items(): + if isinstance(val, torch.Tensor): + print(f" {key}: 
shape {val.shape}, dtype {val.dtype}") + print(f" Value range: [{val.min():.6f}, {val.max():.6f}], mean: {val.mean():.6f}") + elif isinstance(val, list): + print(f" {key}: list of {len(val)} items") + elif isinstance(data, torch.Tensor): + print(f" Shape: {data.shape}, dtype {data.dtype}") + print(f" Value range: [{data.min():.6f}, {data.max():.6f}], mean: {data.mean():.6f}") + +print("\n\n" + "=" * 100) +print("KEY DIFFERENCES SUMMARY") +print("=" * 100) +print(""" +1. DATA SCOPE: + - osisaf_nh_sic_all: Multiple years of data (likely 1979-2023 or similar full record) + - osisaf_nh_sic: Only 2018 data + +2. STRUCTURE: + - Both should follow dict with 'data' key structure + - Main difference: time dimension size (multi-year vs single year) + +3. SPLIT DISTRIBUTION: + - osisaf_nh_sic_all: Large train/valid/test datasets (e.g., 70/15/15 split) + - osisaf_nh_sic: Smaller 2018 datasets (365 days split) + +4. USE CASE: + - osisaf_nh_sic_all: Full training dataset for ML models + - osisaf_nh_sic: Limited dataset, possibly for testing/debugging or year-specific analysis +""") + +EOF diff --git a/data.use_simulator=false b/data.use_simulator=false new file mode 100644 index 00000000..e69de29b diff --git a/debug_masked.sh b/debug_masked.sh new file mode 100644 index 00000000..e9f218c8 --- /dev/null +++ b/debug_masked.sh @@ -0,0 +1,31 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 1:00:00 +#SBATCH --nodes 1 +#SBATCH --gpus 1 +#SBATCH --tasks-per-node 4 +#SBATCH --job-name debug_masked + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +export HYDRA_FULL_ERROR=1 + +python -m autocast.scripts.train.encoder_processor_decoder \ + 
--config-path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/configs \ + datamodule=osisaf_nh_sic \ + model.processor=masked_flow_matching \ + +model.processor.mask_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt \ + trainer.max_epochs=1 \ + trainer.accelerator=gpu \ + trainer.devices=1 diff --git a/download_osisaf.sh b/download_osisaf.sh new file mode 100644 index 00000000..f61f0f6f --- /dev/null +++ b/download_osisaf.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 04:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name=download_osisaf_opendap +#SBATCH --output=logs/download_osisaf_opendap_%j.out +#SBATCH --error=logs/download_osisaf_opendap_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python download_osisaf_iceconc_only.py \ No newline at end of file diff --git a/download_osisaf_full.sh b/download_osisaf_full.sh new file mode 100644 index 00000000..1d0dcf63 --- /dev/null +++ b/download_osisaf_full.sh @@ -0,0 +1,115 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 03:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=32G +#SBATCH --job-name download_osisaf +#SBATCH --output=logs/download_osisaf_%j.out +#SBATCH --error=logs/download_osisaf_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 
'EOF' +import xarray as xr +import os +import pandas as pd + +print("=" * 100) +print("DOWNLOADING OSI-SAF DATASET BY YEAR") +print("=" * 100) + +# OPENDAP URL for aggregated dataset +opendap_url = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" +output_dir = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf" + +os.makedirs(output_dir, exist_ok=True) + +# Years to download (adjust range as needed) +# OSI-450-a1 covers 1979-2015 (CDR) +# For full coverage, you might want 1979-2023 or check available years +start_year = 1978 +end_year = 2020 + +print(f"\nOpening OPENDAP dataset: {opendap_url}") +print(f"Output directory: {output_dir}") +print(f"Downloading years: {start_year} to {end_year}") +print("=" * 100) + +try: + # Open the remote dataset (lazy loading) + print("\nConnecting to THREDDS server...") + ds = xr.open_dataset(opendap_url, engine='netcdf4', chunks={'time': 365}) + + print(f"✓ Connected successfully!") + print(f"\nDataset info:") + print(f" Dimensions: {dict(ds.dims)}") + print(f" Variables: {list(ds.data_vars)}") + print(f" Time range: {ds.time.values[0]} to {ds.time.values[-1]}") + + # Download year by year + for year in range(start_year, end_year + 1): + output_file = os.path.join(output_dir, f"osisaf_nh_{year}.nc") + + # Skip if already downloaded + if os.path.exists(output_file): + print(f"\n[{year}] File already exists, skipping: {output_file}") + continue + + print(f"\n[{year}] Selecting data for year {year}...") + + try: + # Select year's data + ds_year = ds.sel(time=str(year)) + + # Check if year has data + if len(ds_year.time) == 0: + print(f"[{year}] ⚠ No data available, skipping") + continue + + print(f"[{year}] Found {len(ds_year.time)} timesteps") + print(f"[{year}] Downloading...") + + # Download and save + ds_year.to_netcdf(output_file) + + file_size = os.path.getsize(output_file) / (1024**2) # MB + print(f"[{year}] ✓ Downloaded: {output_file} ({file_size:.1f} MB)") + + 
ds_year.close() + + except Exception as e: + print(f"[{year}] ✗ Error: {e}") + if os.path.exists(output_file): + os.remove(output_file) + continue + + print("\n" + "=" * 100) + print("DOWNLOAD COMPLETE") + print("=" * 100) + + # List all downloaded files + files = sorted([f for f in os.listdir(output_dir) if f.endswith('.nc')]) + print(f"\nTotal files downloaded: {len(files)}") + for f in files: + size = os.path.getsize(os.path.join(output_dir, f)) / (1024**2) + print(f" {f}: {size:.1f} MB") + + ds.close() + +except Exception as e: + print(f"\n✗ Fatal error: {e}") + raise + +EOF diff --git a/download_osisaf_iceconc_only.py b/download_osisaf_iceconc_only.py new file mode 100644 index 00000000..7ad71c5d --- /dev/null +++ b/download_osisaf_iceconc_only.py @@ -0,0 +1,47 @@ +"""Download individual OSI-SAF yearly files and combine them.""" + +import xarray as xr +import logging +from pathlib import Path +from urllib.request import urlopen +import os + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +output_dir = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw") +output_dir.mkdir(parents=True, exist_ok=True) + +# List of OSI-SAF files by year (example - you may need to adjust URLs) +base_url = "https://thredds.met.no/thredds/fileServer/osisaf/met.no/reprocessed/ice/" + +log.info("Downloading OSI-SAF yearly files...") +log.info("Note: This attempts to download individual files. 
If THREDDS is down, consider alternative sources.") + +# Try alternative: download just ice_conc variable to smaller file +opendap_url = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" + +log.info(f"Attempting selective load from {opendap_url}") +try: + # Load ONLY the ice_conc variable (skip uncertainties) + ds = xr.open_dataset(opendap_url, engine='netcdf4') + ice_conc = ds[['ice_conc', 'time', 'lat', 'lon']].copy() + + log.info(f"Loaded ice_conc variable: {ice_conc['ice_conc'].shape}") + log.info(f"Size: {ice_conc.nbytes / 1e9:.2f} GB") + + output_file = output_dir / "raw_osisaf_nh_sic_all.nc" + log.info(f"Saving to {output_file}...") + + ice_conc.to_netcdf(output_file, engine='netcdf4', unlimited_dims=['time']) + + log.info(f"✓ Success! File saved: {output_file}") + log.info(f"Size: {output_file.stat().st_size / 1e9:.2f} GB") + +except Exception as e: + log.error(f"OPeNDAP failed: {e}") + log.info("THREDDS server appears to be unstable. Consider:") + log.info("1. Waiting and retrying later") + log.info("2. Downloading individual year files from: https://www.osi-saf.org/") + log.info("3. 
Using NSIDC or other mirror") + raise \ No newline at end of file diff --git a/download_osisaf_opendap.py b/download_osisaf_opendap.py new file mode 100644 index 00000000..a1419cf2 --- /dev/null +++ b/download_osisaf_opendap.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +"""Download OSI-SAF data via OPeNDAP.""" + +import xarray as xr +import logging +from pathlib import Path + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +# OPeNDAP endpoint +opendap_url = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" + +output_path = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw/raw_osisaf_nh_sic_all.nc") +output_path.parent.mkdir(parents=True, exist_ok=True) + +log.info(f"Opening OPeNDAP dataset: {opendap_url}") +try: + # Open with xarray via OPeNDAP + log.info("Connecting to server...") + ds = xr.open_dataset(opendap_url, engine='netcdf4') + + log.info(f"Dataset loaded successfully!") + log.info(f"Dimensions: {dict(ds.dims)}") + log.info(f"Variables: {list(ds.data_vars)}") + log.info(f"Coordinates: {list(ds.coords)}") + + # Get size estimate + total_size = 0 + for var in ds.data_vars: + var_size = ds[var].nbytes / 1e9 + log.info(f" {var}: {var_size:.2f} GB") + total_size += var_size + + log.info(f"Total estimated size: {total_size:.2f} GB") + + # Save to netCDF + log.info(f"\nSaving to {output_path}...") + ds.to_netcdf(output_path, engine='netcdf4') + + final_size = output_path.stat().st_size / 1e9 + log.info(f"✓ Download complete!") + log.info(f"File size: {final_size:.2f} GB") + log.info(f"Location: {output_path}") + +except Exception as e: + log.error(f"Failed to download: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise diff --git a/evaluation_metrics.csv b/evaluation_metrics.csv new file mode 100644 index 00000000..f7195660 --- /dev/null +++ b/evaluation_metrics.csv @@ -0,0 +1,362 @@ 
+dataset_split,batch_index,num_samples,mse,rmse,vrmse +test,0,1,2.2954108715057373,1.515061378479004,0.051374468952417374 +test,1,1,2.9118940830230713,1.7064273357391357,0.057619113475084305 +test,2,1,2.7958152294158936,1.6720691919326782,0.05643235146999359 +test,3,1,2.654449701309204,1.6292481422424316,0.055268220603466034 +test,4,1,2.1709225177764893,1.473405122756958,0.05009395256638527 +test,5,1,2.7031805515289307,1.6441352367401123,0.05583777651190758 +test,6,1,2.46683931350708,1.570617437362671,0.053137317299842834 +test,7,1,2.1301631927490234,1.4595078229904175,0.049306273460388184 +test,8,1,2.1048049926757812,1.4507945775985718,0.04882647842168808 +test,9,1,2.11932635307312,1.455790638923645,0.048887088894844055 +test,10,1,2.400129795074463,1.549235224723816,0.05210896208882332 +test,11,1,1.9075790643692017,1.3811513185501099,0.04643389582633972 +test,12,1,1.916105031967163,1.3842344284057617,0.046338215470314026 +test,13,1,2.499119281768799,1.5808602571487427,0.05264642834663391 +test,14,1,2.1867587566375732,1.478769302368164,0.04920191690325737 +test,15,1,2.5200085639953613,1.5874534845352173,0.05275791883468628 +test,16,1,2.497490882873535,1.5803451538085938,0.052443474531173706 +test,17,1,3.4509170055389404,1.8576643466949463,0.061463724821805954 +test,18,1,3.1880874633789062,1.7855216264724731,0.05911692976951599 +test,19,1,3.687396764755249,1.9202595949172974,0.06359876692295074 +test,20,1,2.383582353591919,1.5438854694366455,0.05112234875559807 +test,21,1,1.873927354812622,1.3689146041870117,0.04515370726585388 +test,22,1,1.7629759311676025,1.3277710676193237,0.04356204718351364 +test,23,1,2.0926218032836914,1.446589708328247,0.047239888459444046 +test,24,1,2.4315032958984375,1.5593278408050537,0.05085981264710426 +test,25,1,3.1813626289367676,1.783637523651123,0.05824865400791168 +test,26,1,2.90346097946167,1.7039544582366943,0.05575438216328621 +test,27,1,3.0454962253570557,1.7451350688934326,0.057275693863630295 
+test,28,1,3.887505054473877,1.9716757535934448,0.06502976268529892 +test,29,1,3.454895496368408,1.8587349653244019,0.061283793300390244 +test,30,1,3.1552960872650146,1.7763153314590454,0.05823983997106552 +test,31,1,2.62800931930542,1.621113657951355,0.052869513630867004 +test,32,1,2.719716787338257,1.6491563320159912,0.05364842712879181 +test,33,1,3.163675308227539,1.7786723375320435,0.05791918933391571 +test,34,1,3.0217838287353516,1.7383278608322144,0.05664012208580971 +test,35,1,3.8730809688568115,1.9680144786834717,0.06403011828660965 +test,36,1,4.209066867828369,2.051600933074951,0.06682595610618591 +test,37,1,2.887706756591797,1.699325442314148,0.055619169026613235 +test,38,1,3.226422071456909,1.7962243556976318,0.058951236307621 +test,39,1,3.2219398021698,1.7949762344360352,0.05889669805765152 +test,40,1,2.44561767578125,1.5638470649719238,0.051186319440603256 +test,41,1,2.659632921218872,1.63083815574646,0.053215641528367996 +test,42,1,2.3028204441070557,1.5175046920776367,0.04948476701974869 +test,43,1,2.9048562049865723,1.7043638229370117,0.05561501532793045 +test,44,1,3.1342484951019287,1.7703808546066284,0.05784883722662926 +test,45,1,3.091860771179199,1.758368730545044,0.05740339681506157 +test,46,1,2.841041088104248,1.6855387687683105,0.055000923573970795 +test,47,1,3.114622116088867,1.764829158782959,0.05739966779947281 +test,48,1,2.0661940574645996,1.437426209449768,0.04676426574587822 +test,49,1,2.2594852447509766,1.503158450126648,0.048918820917606354 +test,50,1,2.32588791847229,1.5250861644744873,0.04970494657754898 +test,51,1,2.5280234813690186,1.5899759531021118,0.051646072417497635 +test,52,1,2.4160773754119873,1.5543736219406128,0.0503070242702961 +test,53,1,3.05686092376709,1.7483880519866943,0.056341752409935 +test,54,1,4.028025150299072,2.0069940090179443,0.06464318186044693 +test,55,1,2.1850812435150146,1.478201985359192,0.04748241975903511 +test,56,1,3.0244226455688477,1.7390867471694946,0.05571410432457924 
+test,57,1,2.3033013343811035,1.5176631212234497,0.04847929999232292 +test,58,1,3.222015142440796,1.794997215270996,0.057393915951251984 +test,59,1,2.67490816116333,1.6355146169662476,0.052432697266340256 +test,60,1,2.2068681716918945,1.4855531454086304,0.047684233635663986 +test,61,1,2.2847323417663574,1.5115331411361694,0.04851985350251198 +test,62,1,2.558305025100708,1.5994702577590942,0.05144129693508148 +test,63,1,3.5918099880218506,1.895207166671753,0.0609498955309391 +test,64,1,4.95081090927124,2.2250418663024902,0.07187968492507935 +test,65,1,3.9675188064575195,1.9918631315231323,0.06468874216079712 +test,66,1,4.6013407707214355,2.145073652267456,0.07019782811403275 +test,67,1,3.773808479309082,1.9426292181015015,0.06355676054954529 +test,68,1,2.792924642562866,1.6712045669555664,0.054547034204006195 +test,69,1,3.3873047828674316,1.8404631614685059,0.05996580421924591 +test,70,1,2.572648048400879,1.603947639465332,0.0521954745054245 +test,71,1,4.366348743438721,2.08958101272583,0.06819608807563782 +test,72,1,3.034885883331299,1.7420923709869385,0.05683807656168938 +test,73,1,3.450700283050537,1.857606053352356,0.06081685423851013 +test,74,1,3.9475831985473633,1.9868525266647339,0.06517399847507477 +test,75,1,2.8165154457092285,1.6782476902008057,0.055155687034130096 +test,76,1,3.0701072216033936,1.7521721124649048,0.05776086822152138 +test,77,1,2.2627298831939697,1.5042372941970825,0.04957323521375656 +test,78,1,2.6256625652313232,1.6203895807266235,0.05346110463142395 +test,79,1,2.0206785202026367,1.4215056896209717,0.04695673659443855 +test,80,1,2.7878189086914062,1.6696763038635254,0.05544239282608032 +test,81,1,2.8649959564208984,1.6926299333572388,0.05626785010099411 +test,82,1,3.3146286010742188,1.8206121921539307,0.060440145432949066 +test,83,1,3.0295112133026123,1.740549087524414,0.057721253484487534 +test,84,1,2.516153573989868,1.5862388610839844,0.052551936358213425 +test,85,1,2.564525842666626,1.6014137268066406,0.052993424236774445 
+test,86,1,1.9018877744674683,1.3790894746780396,0.045798081904649734 +test,87,1,2.1262428760528564,1.4581642150878906,0.048525914549827576 +test,88,1,2.445920944213867,1.5639439821243286,0.05207239091396332 +test,89,1,2.4940025806427,1.579241156578064,0.05254611745476723 +test,90,1,2.1028060913085938,1.4501055479049683,0.04814058914780617 +test,91,1,2.260892152786255,1.5036263465881348,0.04975688084959984 +test,92,1,3.0891377925872803,1.757594347000122,0.05807933211326599 +test,93,1,2.3905234336853027,1.5461317300796509,0.0511934757232666 +test,94,1,1.7432063817977905,1.3203054666519165,0.04381454735994339 +test,95,1,1.7281447649002075,1.3145891427993774,0.043788958340883255 +test,96,1,1.7104318141937256,1.3078347444534302,0.04355235770344734 +test,97,1,2.17704439163208,1.4754810333251953,0.04912855103611946 +test,98,1,2.1910762786865234,1.4802284240722656,0.04933345690369606 +test,99,1,2.2448036670684814,1.4982669353485107,0.05010167881846428 +test,100,1,1.810872197151184,1.3456865549087524,0.045252569019794464 +test,101,1,1.6473206281661987,1.283479928970337,0.04332393407821655 +test,102,1,1.5080294609069824,1.2280185222625732,0.04154270514845848 +test,103,1,2.305997848510742,1.518551230430603,0.05149802565574646 +test,104,1,2.1508865356445312,1.4665900468826294,0.04971899464726448 +test,105,1,2.14375638961792,1.464157223701477,0.049746204167604446 +test,106,1,2.183213710784912,1.4775701761245728,0.050259076058864594 +test,107,1,1.9434000253677368,1.3940588235855103,0.047586824744939804 +test,108,1,2.3991003036499023,1.548902988433838,0.05322019010782242 +test,109,1,1.8172962665557861,1.3480713367462158,0.046444121748209 +test,110,1,1.3124101161956787,1.1456047296524048,0.039573196321725845 +test,111,1,1.8562862873077393,1.36245596408844,0.04729318246245384 +test,112,1,1.8009161949157715,1.3419822454452515,0.046454042196273804 +test,113,1,2.1952483654022217,1.4816370010375977,0.05124979466199875 
+test,114,1,2.0062623023986816,1.4164259433746338,0.049140457063913345 +test,115,1,2.6650781631469727,1.6325067281723022,0.05654789134860039 +test,116,1,1.6451934576034546,1.2826509475708008,0.044344473630189896 +test,117,1,1.951865792274475,1.3970918655395508,0.04824506863951683 +test,118,1,1.6010217666625977,1.2653149366378784,0.04375791922211647 +test,119,1,1.875993251800537,1.3696690797805786,0.04749307781457901 +test,120,1,1.4228092432022095,1.1928156614303589,0.04148752987384796 +test,121,1,1.5227323770523071,1.2339904308319092,0.042995765805244446 +test,122,1,1.749778151512146,1.3227918148040771,0.04640747234225273 +test,123,1,1.8816224336624146,1.3717224597930908,0.048562876880168915 +test,124,1,2.0635859966278076,1.436518669128418,0.05107889696955681 +test,125,1,1.6518712043762207,1.2852513790130615,0.04572317749261856 +test,126,1,2.045933485031128,1.430361270904541,0.0511108823120594 +test,127,1,3.1149895191192627,1.7649333477020264,0.06363170593976974 +test,128,1,2.1022489070892334,1.4499133825302124,0.05237678065896034 +test,129,1,1.712256669998169,1.3085322380065918,0.04737139120697975 +test,130,1,1.2559350728988647,1.12068510055542,0.04063364490866661 +test,131,1,1.2858623266220093,1.1339586973190308,0.04121434688568115 +test,132,1,1.3391385078430176,1.1572115421295166,0.0422217883169651 +test,133,1,1.3663281202316284,1.1689003705978394,0.04289137199521065 +test,134,1,1.8348497152328491,1.354566216468811,0.05017821863293648 +test,135,1,2.810161828994751,1.6763536930084229,0.06227375939488411 +test,136,1,1.7028728723526,1.3049416542053223,0.04867463931441307 +test,137,1,1.959765076637268,1.3999160528182983,0.052464697510004044 +test,138,1,2.3854081630706787,1.5444766283035278,0.058490075170993805 +test,139,1,2.2807021141052246,1.5101993083953857,0.05718511343002319 +test,140,1,1.856248378753662,1.3624420166015625,0.051604289561510086 +test,141,1,2.0841269493103027,1.4436506032943726,0.055024437606334686 
+test,142,1,1.8911964893341064,1.375207781791687,0.05241863802075386 +test,143,1,1.8861762285232544,1.37338125705719,0.0525202639400959 +test,144,1,1.5339298248291016,1.2385191917419434,0.04752679169178009 +test,145,1,1.5635592937469482,1.2504236698150635,0.04784224554896355 +test,146,1,1.7175310850143433,1.3105461597442627,0.04991820082068443 +test,147,1,1.4310272932052612,1.1962555646896362,0.04556933790445328 +test,148,1,1.2960678339004517,1.138449788093567,0.043547190725803375 +test,149,1,1.5549286603927612,1.2469677925109863,0.04792758822441101 +test,150,1,1.6031619310379028,1.2661603689193726,0.0487748384475708 +test,151,1,1.3660215139389038,1.1687692403793335,0.04509852081537247 +test,152,1,1.5751465559005737,1.255048394203186,0.048614710569381714 +test,153,1,1.4901491403579712,1.2207165956497192,0.04762700945138931 +test,154,1,2.4554760456085205,1.5669958591461182,0.061750221997499466 +test,155,1,1.6445389986038208,1.282395839691162,0.050615377724170685 +test,156,1,2.429178237915039,1.558582067489624,0.06203329563140869 +test,157,1,2.2034668922424316,1.484407901763916,0.05971900001168251 +test,158,1,1.8947087526321411,1.3764841556549072,0.055944398045539856 +test,159,1,2.092768669128418,1.4466404914855957,0.05946389213204384 +test,160,1,2.1448285579681396,1.4645233154296875,0.0609576553106308 +test,161,1,2.1002249717712402,1.449215292930603,0.06120578572154045 +test,162,1,2.3297977447509766,1.5263675451278687,0.06550471484661102 +test,163,1,2.4758219718933105,1.5734745264053345,0.06808462738990784 +test,164,1,2.4648289680480957,1.5699774026870728,0.06877274811267853 +test,165,1,1.8305891752243042,1.3529926538467407,0.05935506895184517 +test,166,1,1.782676339149475,1.3351690769195557,0.05825292691588402 +test,167,1,2.1087255477905273,1.4521450996398926,0.0631764754652977 +test,168,1,1.9950220584869385,1.4124524593353271,0.061956360936164856 +test,169,1,1.5521081686019897,1.2458363771438599,0.05487050488591194 
+test,170,1,1.674618124961853,1.2940703630447388,0.057528648525476456 +test,171,1,2.0230376720428467,1.4223352670669556,0.06381532549858093 +test,172,1,2.1824817657470703,1.4773224592208862,0.06668535619974136 +test,173,1,1.981589436531067,1.4076894521713257,0.06328076124191284 +test,174,1,2.4803097248077393,1.5748999118804932,0.07114588469266891 +test,175,1,2.754852771759033,1.659774899482727,0.07584216445684433 +test,176,1,3.2689461708068848,1.8080227375030518,0.08359692245721817 +test,177,1,3.028475761413574,1.7402516603469849,0.08114993572235107 +test,178,1,2.6182775497436523,1.6181092262268066,0.07682832330465317 +test,179,1,2.1291918754577637,1.4591751098632812,0.07051297277212143 +test,180,1,1.8528506755828857,1.3611946105957031,0.0659262016415596 +test,181,1,1.638350009918213,1.2799804210662842,0.062153588980436325 +test,182,1,1.7872720956802368,1.3368889093399048,0.06505837291479111 +test,183,1,1.5985397100448608,1.264333724975586,0.06145985424518585 +test,184,1,1.6144167184829712,1.270596981048584,0.06135650724172592 +test,185,1,1.7196977138519287,1.3113723993301392,0.0635027140378952 +test,186,1,1.8805797100067139,1.3713423013687134,0.06698291003704071 +test,187,1,2.117208242416382,1.455062985420227,0.07158409804105759 +test,188,1,1.9141863584518433,1.383541226387024,0.06833413988351822 +test,189,1,1.9296875,1.3891319036483765,0.06909990310668945 +test,190,1,2.188912868499756,1.4794975519180298,0.07423880696296692 +test,191,1,2.2581565380096436,1.5027164220809937,0.07623597979545593 +test,192,1,1.9116575717926025,1.382627010345459,0.07123709470033646 +test,193,1,1.4618622064590454,1.2090749740600586,0.06315236538648605 +test,194,1,1.8380074501037598,1.3557313680648804,0.07183540612459183 +test,195,1,1.6989803314208984,1.3034493923187256,0.06989964097738266 +test,196,1,1.716986894607544,1.3103384971618652,0.07046730071306229 +test,197,1,1.327892780303955,1.1523423194885254,0.06289245188236237 
+test,198,1,1.2581671476364136,1.121680498123169,0.06203201040625572 +test,199,1,1.6906894445419312,1.3002651929855347,0.07236552238464355 +test,200,1,2.047973871231079,1.4310743808746338,0.08068658411502838 +test,201,1,2.0675742626190186,1.437906265258789,0.08307871222496033 +test,202,1,1.7132277488708496,1.3089032173156738,0.07828383892774582 +test,203,1,1.6269276142120361,1.2755106687545776,0.0783715769648552 +test,204,1,1.351640224456787,1.1626006364822388,0.07270434498786926 +test,205,1,1.7876601219177246,1.3370341062545776,0.08547443896532059 +test,206,1,0.9818444848060608,0.9908806681632996,0.06490179151296616 +test,207,1,1.8254519701004028,1.3510929346084595,0.09071721136569977 +test,208,1,1.5269063711166382,1.2356805801391602,0.08189494907855988 +test,209,1,1.325319766998291,1.1512253284454346,0.0761263370513916 +test,210,1,1.57725989818573,1.255890130996704,0.08472301810979843 +test,211,1,1.2788721323013306,1.1308722496032715,0.07715948671102524 +test,212,1,1.45255446434021,1.2052196264266968,0.08274233341217041 +test,213,1,1.649444580078125,1.2843070030212402,0.08832326531410217 +test,214,1,1.7838354110717773,1.3356029987335205,0.09174840152263641 +test,215,1,1.7643964290618896,1.3283058404922485,0.09106408804655075 +test,216,1,1.982025146484375,1.4078441858291626,0.0967753455042839 +test,217,1,1.762534737586975,1.32760488986969,0.09234175831079483 +test,218,1,1.7910370826721191,1.3382962942123413,0.09466087073087692 +test,219,1,1.370544672012329,1.1707026958465576,0.08388327807188034 +test,220,1,2.367856025695801,1.5387839078903198,0.10871592164039612 +test,221,1,1.7140567302703857,1.3092198371887207,0.09340906888246536 +test,222,1,1.5275001525878906,1.235920786857605,0.08842483907938004 +test,223,1,1.3597664833068848,1.1660902500152588,0.08293834328651428 +test,224,1,1.152630090713501,1.0736061334609985,0.07745048403739929 +test,225,1,1.2548942565917969,1.12022066116333,0.08124080300331116 
+test,226,1,1.3487449884414673,1.1613547801971436,0.08372946828603745 +test,227,1,1.3814704418182373,1.1753597259521484,0.08483397215604782 +test,228,1,1.0794646739959717,1.0389728546142578,0.07546224445104599 +test,229,1,1.3935942649841309,1.1805059909820557,0.08719667792320251 +test,230,1,1.6206634044647217,1.2730528116226196,0.09335753321647644 +test,231,1,1.3482630252838135,1.1611472368240356,0.0837354063987732 +test,232,1,1.921094536781311,1.3860355615615845,0.10086637735366821 +test,233,1,2.5474231243133545,1.5960649251937866,0.11613990366458893 +test,234,1,2.0415751934051514,1.4288370609283447,0.1043892353773117 +test,235,1,1.7437459230422974,1.320509672164917,0.09857044368982315 +test,236,1,1.5432124137878418,1.2422610521316528,0.09181026369333267 +test,237,1,1.8031747341156006,1.3428233861923218,0.09645778685808182 +test,238,1,1.5992175340652466,1.264601707458496,0.09043888002634048 +test,239,1,0.8992719054222107,0.9482994675636292,0.06765250116586685 +test,240,1,0.8416992425918579,0.9174416661262512,0.06539944559335709 +test,241,1,0.8171206712722778,0.9039472937583923,0.06394126266241074 +test,242,1,1.6371183395385742,1.2794992923736572,0.09028323739767075 +test,243,1,1.5750038623809814,1.2549915313720703,0.08950062096118927 +test,244,1,1.2372080087661743,1.1122984886169434,0.07939878851175308 +test,245,1,1.2590235471725464,1.1220622062683105,0.07871098816394806 +test,246,1,1.545353889465332,1.2431225776672363,0.08424890786409378 +test,247,1,1.589030146598816,1.2605674266815186,0.08514686673879623 +test,248,1,1.0662267208099365,1.0325825214385986,0.06921332329511642 +test,249,1,0.9980116486549377,0.9990053176879883,0.0671616718173027 +test,250,1,0.8501986861228943,0.9220622181892395,0.061714429408311844 +test,251,1,1.2765445709228516,1.129842758178711,0.07654483616352081 +test,252,1,1.6556521654129028,1.2867214679718018,0.08732561767101288 +test,253,1,1.2254092693328857,1.1069821119308472,0.07467290759086609 
+test,254,1,2.843947649002075,1.6864007711410522,0.11132682859897614 +test,255,1,1.7161669731140137,1.3100255727767944,0.08603314310312271 +test,256,1,2.2099690437316895,1.4865964651107788,0.09768711775541306 +test,257,1,1.4082307815551758,1.186689019203186,0.07738002389669418 +test,258,1,1.4690607786178589,1.2120481729507446,0.07808894664049149 +test,259,1,1.9737098217010498,1.4048877954483032,0.09038293361663818 +test,260,1,1.643810749053955,1.2821118831634521,0.08291161805391312 +test,261,1,1.7332544326782227,1.3165311813354492,0.08507753163576126 +test,262,1,2.5525155067443848,1.5976593494415283,0.10293659567832947 +test,263,1,2.202458620071411,1.4840682744979858,0.0947738066315651 +test,264,1,1.285496473312378,1.1337974071502686,0.07219336181879044 +test,265,1,1.364380955696106,1.168067216873169,0.07455698400735855 +test,266,1,1.2788596153259277,1.1308667659759521,0.07174985855817795 +test,267,1,1.51105797290802,1.2292510271072388,0.07776337116956711 +test,268,1,2.6466763019561768,1.6268608570098877,0.1020447388291359 +test,269,1,1.8768099546432495,1.3699671030044556,0.08631770312786102 +test,270,1,1.475661277770996,1.2147680521011353,0.07586326450109482 +test,271,1,1.4846807718276978,1.2184747457504272,0.07558241486549377 +test,272,1,1.0059425830841064,1.0029668807983398,0.06181968003511429 +test,273,1,0.978197455406189,0.9890386462211609,0.06044771149754524 +test,274,1,1.0927503108978271,1.045346975326538,0.0630945935845375 +test,275,1,1.2021404504776,1.0964215993881226,0.0651240199804306 +test,276,1,1.1451448202133179,1.0701143741607666,0.062842458486557 +test,277,1,1.313506841659546,1.1460832357406616,0.06634017080068588 +test,278,1,1.0583326816558838,1.0287529230117798,0.059206265956163406 +test,279,1,1.3496119976043701,1.161728024482727,0.0666448324918747 +test,280,1,1.5890319347381592,1.2605681419372559,0.07139711081981659 +test,281,1,1.7662246227264404,1.328993797302246,0.07463937252759933 
+test,282,1,1.5561213493347168,1.2474459409713745,0.0697803795337677 +test,283,1,1.3424038887023926,1.1586215496063232,0.06574226170778275 +test,284,1,1.388954997062683,1.1785393953323364,0.06629349291324615 +test,285,1,1.161298394203186,1.0776355266571045,0.059474773705005646 +test,286,1,1.0260084867477417,1.0129207372665405,0.05525468662381172 +test,287,1,1.124598503112793,1.0604709386825562,0.057800356298685074 +test,288,1,1.279191255569458,1.1310133934020996,0.06100854277610779 +test,289,1,1.6214797496795654,1.2733733654022217,0.06848672032356262 +test,290,1,1.787825345993042,1.337095856666565,0.0715288296341896 +test,291,1,1.8459923267364502,1.3586729764938354,0.07146041095256805 +test,292,1,2.474635124206543,1.5730973482131958,0.08188294619321823 +test,293,1,2.230742931365967,1.4935672283172607,0.07660026103258133 +test,294,1,2.6638050079345703,1.632116675376892,0.08264830708503723 +test,295,1,2.222914218902588,1.4909440279006958,0.07486485689878464 +test,296,1,3.477806568145752,1.8648878335952759,0.09232683479785919 +test,297,1,4.146582126617432,2.03631591796875,0.0988449975848198 +test,298,1,2.423983097076416,1.5569145679473877,0.07455474138259888 +test,299,1,3.171497344970703,1.7808698415756226,0.08383973687887192 +test,300,1,2.452932357788086,1.5661840438842773,0.07251805067062378 +test,301,1,1.9508548974990845,1.3967300653457642,0.06359033286571503 +test,302,1,2.5594472885131836,1.5998272895812988,0.07229283452033997 +test,303,1,2.3105432987213135,1.5200471878051758,0.0683099776506424 +test,304,1,2.2559220790863037,1.5019726753234863,0.06708943098783493 +test,305,1,2.3303215503692627,1.5265390872955322,0.06733492016792297 +test,306,1,2.2700438499450684,1.5066664218902588,0.06536324322223663 +test,307,1,2.847651481628418,1.6874985694885254,0.07225848734378815 +test,308,1,2.0922274589538574,1.4464534521102905,0.061770033091306686 +test,309,1,2.785921812057495,1.6691080331802368,0.07079093903303146 
+test,310,1,2.042754650115967,1.42924964427948,0.05999332293868065 +test,311,1,3.0779519081115723,1.7544093132019043,0.07303162664175034 +test,312,1,4.0506815910339355,2.0126304626464844,0.0841449722647667 +test,313,1,4.456706523895264,2.111091375350952,0.08796893060207367 +test,314,1,3.2296080589294434,1.7971110343933105,0.07450201362371445 +test,315,1,2.8332719802856445,1.6832325458526611,0.06981714814901352 +test,316,1,2.198408365249634,1.4827030897140503,0.06130814179778099 +test,317,1,1.924912691116333,1.387412190437317,0.05697503685951233 +test,318,1,2.332709312438965,1.5273209810256958,0.06205711141228676 +test,319,1,2.0302882194519043,1.4248818159103394,0.057495612651109695 +test,320,1,2.644354820251465,1.6261472702026367,0.06473422795534134 +test,321,1,2.3039162158966064,1.5178656578063965,0.05950102210044861 +test,322,1,2.2876837253570557,1.5125091075897217,0.05897175893187523 +test,323,1,2.7445638179779053,1.656672477722168,0.06462698429822922 +test,324,1,2.2936291694641113,1.5144731998443604,0.059146929532289505 +test,325,1,3.1641287803649902,1.7787997722625732,0.06933474540710449 +test,326,1,3.187115430831909,1.7852493524551392,0.06938684731721878 +test,327,1,2.1182260513305664,1.4554126262664795,0.056479308754205704 +test,328,1,2.1620326042175293,1.470385193824768,0.056994300335645676 +test,329,1,4.707685947418213,2.16972017288208,0.08377216756343842 +test,330,1,2.6000726222991943,1.6124740839004517,0.06199172884225845 +test,331,1,3.453089475631714,1.8582490682601929,0.07139220833778381 +test,332,1,3.1603190898895264,1.7777286767959595,0.06810157746076584 +test,333,1,2.7749149799346924,1.6658076047897339,0.06319059431552887 +test,334,1,2.7552425861358643,1.6598923206329346,0.062120221555233 +test,335,1,2.2248802185058594,1.4916032552719116,0.055360645055770874 +test,336,1,2.340345859527588,1.5298188924789429,0.05648922547698021 +test,337,1,2.6135334968566895,1.616642713546753,0.05945281311869621 
+test,338,1,2.908634901046753,1.7054719924926758,0.06262391805648804 +test,339,1,2.7268693447113037,1.6513235569000244,0.06067650020122528 +test,340,1,3.351083278656006,1.8305964469909668,0.06716714799404144 +test,341,1,3.2705419063568115,1.8084639310836792,0.06609812378883362 +test,342,1,3.3554346561431885,1.8317846059799194,0.06660671532154083 +test,343,1,2.9759652614593506,1.7250986099243164,0.06263277679681778 +test,344,1,2.8759210109710693,1.6958540678024292,0.06126667931675911 +test,345,1,2.7232494354248047,1.6502270698547363,0.05911175161600113 +test,346,1,1.9141619205474854,1.3835324048995972,0.04927384480834007 +test,347,1,2.0608177185058594,1.4355548620224,0.05113440379500389 +test,348,1,2.8294174671173096,1.6820871829986572,0.05937372148036957 +test,349,1,1.7836834192276,1.3355461359024048,0.046906132251024246 +test,350,1,2.437979221343994,1.5614029169082642,0.054769545793533325 +test,351,1,2.4132819175720215,1.5534741878509521,0.05447549745440483 +test,352,1,2.6002845764160156,1.6125397682189941,0.05642378330230713 +test,353,1,2.885568618774414,1.6986961364746094,0.05913061648607254 +test,354,1,2.5901105403900146,1.6093820333480835,0.05601368099451065 +test,355,1,3.7603085041046143,1.9391515254974365,0.06778029352426529 +test,356,1,2.9537250995635986,1.7186404466629028,0.060023702681064606 +test,357,1,1.6534228324890137,1.2858549356460571,0.04477742686867714 +test,358,1,1.7027961015701294,1.3049123287200928,0.04539882019162178 +test,359,1,2.1672489643096924,1.4721579551696777,0.051130522042512894 +test,all,360,2.200034694870313,1.462959815065066,0.06366727451483409 diff --git a/fill_nan_osisaf.sh b/fill_nan_osisaf.sh new file mode 100644 index 00000000..6324f8f2 --- /dev/null +++ b/fill_nan_osisaf.sh @@ -0,0 +1,97 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:30:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name fill_nan_osisaf 
+#SBATCH --output=logs/fill_nan_osisaf_%j.out +#SBATCH --error=logs/fill_nan_osisaf_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +# Activate conda environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import os +from pathlib import Path + +base_path = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all") + +print("=" * 100) +print("FILLING NaN VALUES WITH 0 IN osisaf_nh_sic_all DATASET") +print("=" * 100) + +splits = ["train", "valid", "test"] + +for split in splits: + split_path = base_path / split / "data.pt" + + print(f"\n{'=' * 100}") + print(f"Processing: {split}") + print(f"{'=' * 100}") + print(f"Path: {split_path}") + + # Load the data + print(f"Loading data...") + data_dict = torch.load(split_path, map_location='cpu') + data_tensor = data_dict["data"] + + print(f"Shape: {data_tensor.shape}") + print(f"Dtype: {data_tensor.dtype}") + + # Check NaN statistics + num_nans_before = torch.isnan(data_tensor).sum().item() + total_elements = data_tensor.numel() + nan_percentage_before = (num_nans_before / total_elements) * 100 + + print(f"\nBefore filling:") + print(f" Total elements: {total_elements:,}") + print(f" NaN count: {num_nans_before:,}") + print(f" NaN percentage: {nan_percentage_before:.4f}%") + if num_nans_before > 0: + print(f" Min (ignoring NaN): {data_tensor[~torch.isnan(data_tensor)].min():.6f}") + print(f" Max (ignoring NaN): {data_tensor[~torch.isnan(data_tensor)].max():.6f}") + print(f" Mean (ignoring NaN): {data_tensor[~torch.isnan(data_tensor)].mean():.6f}") + + # Fill NaN with 0 + print(f"\nFilling NaN values with 0...") + data_tensor = torch.nan_to_num(data_tensor, nan=0.0) + data_dict["data"] = data_tensor + + # Check after filling + num_nans_after = torch.isnan(data_tensor).sum().item() + print(f"\nAfter 
filling:") + print(f" NaN count: {num_nans_after}") + print(f" Min: {data_tensor.min():.6f}") + print(f" Max: {data_tensor.max():.6f}") + print(f" Mean: {data_tensor.mean():.6f}") + + # Save back + print(f"\nSaving modified data...") + torch.save(data_dict, split_path) + print(f"✓ Saved to {split_path}") + + # Verify the save + print(f"Verifying save...") + verify_dict = torch.load(split_path, map_location='cpu') + verify_tensor = verify_dict["data"] + verify_nans = torch.isnan(verify_tensor).sum().item() + print(f"✓ Verification: NaN count after reload = {verify_nans}") + +print(f"\n\n{'=' * 100}") +print("✓ COMPLETED: All NaN values in osisaf_nh_sic_all have been replaced with 0") +print(f"{'=' * 100}") + +EOF diff --git a/fill_nan_simple.sh b/fill_nan_simple.sh new file mode 100644 index 00000000..04ec44b7 --- /dev/null +++ b/fill_nan_simple.sh @@ -0,0 +1,60 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 01:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name fill_nan_osisaf +#SBATCH --output=logs/fill_nan_osisaf_%j.out +#SBATCH --error=logs/fill_nan_osisaf_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import os + +print("Filling NaN values with 0 in osisaf_nh_sic_all dataset\n") + +base_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all" + +for split in ["train", "valid", "test"]: + split_path = os.path.join(base_path, split, "data.pt") + + print(f"Processing {split}...") + + # Load + data_dict = torch.load(split_path) + data_tensor = data_dict["data"] + + # Count NaN before + nan_before = torch.isnan(data_tensor).sum().item() + + # Fill NaN with 0 + 
data_tensor[torch.isnan(data_tensor)] = 0.0 + + # Count NaN after + nan_after = torch.isnan(data_tensor).sum().item() + + # Save + torch.save(data_dict, split_path) + + print(f" Shape: {data_tensor.shape}") + print(f" NaN before: {nan_before:,}") + print(f" NaN after: {nan_after:,}") + print(f" ✓ Saved\n") + +print("Done!") + +EOF diff --git a/fill_nan_v2.sh b/fill_nan_v2.sh new file mode 100644 index 00000000..a6a49ab4 --- /dev/null +++ b/fill_nan_v2.sh @@ -0,0 +1,107 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 01:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name fill_nan_v2 +#SBATCH --output=logs/fill_nan_v2_%j.out +#SBATCH --error=logs/fill_nan_v2_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +# Activate conda environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import os + +print("=" * 100) +print("FILLING NaN VALUES WITH 0 IN osisaf_nh_sic_all DATASET") +print("=" * 100) + +base_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all" + +splits = ["train", "valid", "test"] + +for split in splits: + split_path = os.path.join(base_path, split, "data.pt") + + print(f"\n{'=' * 100}") + print(f"Processing: {split}") + print(f"Path: {split_path}") + print(f"{'=' * 100}") + + # Load the data + print("Loading data...") + data_dict = torch.load(split_path) + data_tensor = data_dict["data"] + + print(f"Shape: {data_tensor.shape}") + print(f"Dtype: {data_tensor.dtype}") + + # Count NaN before + nan_count_before = torch.isnan(data_tensor).sum().item() + total_elements = data_tensor.numel() + nan_pct_before = (nan_count_before / total_elements) * 100 + + print(f"\nBefore filling:") + print(f" Total elements: 
{total_elements:,}") + print(f" NaN count: {nan_count_before:,}") + print(f" NaN percentage: {nan_pct_before:.4f}%") + + # Fill NaN with 0 IN PLACE + print("\nFilling NaN with 0...") + data_tensor[torch.isnan(data_tensor)] = 0.0 + + # Count NaN after + nan_count_after = torch.isnan(data_tensor).sum().item() + nan_pct_after = (nan_count_after / total_elements) * 100 + + print(f"\nAfter filling:") + print(f" NaN count: {nan_count_after:,}") + print(f" NaN percentage: {nan_pct_after:.4f}%") + + # Show value statistics (only if not all zeros) + non_zero = (data_tensor != 0).sum().item() + print(f" Non-zero elements: {non_zero:,}") + + if non_zero > 0: + min_val = data_tensor[data_tensor != 0].min().item() + max_val = data_tensor[data_tensor != 0].max().item() + mean_val = data_tensor[data_tensor != 0].mean().item() + print(f" Value range (non-zero): [{min_val:.6f}, {max_val:.6f}]") + print(f" Mean (non-zero): {mean_val:.6f}") + + # Save back + print(f"\nSaving back to {split_path}...") + torch.save(data_dict, split_path) + + # Verify by reloading + print("Verifying by reloading...") + verify_dict = torch.load(split_path) + verify_tensor = verify_dict["data"] + verify_nan_count = torch.isnan(verify_tensor).sum().item() + + if verify_nan_count == 0: + print("✓ Verification PASSED - No NaN values in saved file") + else: + print(f"✗ Verification FAILED - Still has {verify_nan_count} NaN values!") + raise RuntimeError(f"Verification failed for {split}") + +print("\n" + "=" * 100) +print("COMPLETED SUCCESSFULLY") +print("=" * 100) +print("All NaN values in osisaf_nh_sic_all dataset have been filled with 0") + +EOF diff --git a/import b/import new file mode 100644 index 00000000..e69de29b diff --git a/inspect_datasets.sh b/inspect_datasets.sh new file mode 100644 index 00000000..a5e26a2d --- /dev/null +++ b/inspect_datasets.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:30:00 +#SBATCH --nodes=1 +#SBATCH 
--ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=32G +#SBATCH --job-name inspect_datasets +#SBATCH --output=logs/inspect_datasets_%j.out +#SBATCH --error=logs/inspect_datasets_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +# Activate conda environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import xarray as xr +import torch +import numpy as np + +print("=" * 80) +print("RAW OSI-SAF DATASET (netCDF)") +print("=" * 80) + +# Load raw OSI-SAF 2018 netCDF file +raw_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2018.nc" +raw_ds = xr.open_dataset(raw_path) + +print(f"\nType: {type(raw_ds)}") +print(f"\nDimensions: {dict(raw_ds.dims)}") +print(f"\nCoordinates: {list(raw_ds.coords)}") +for coord in raw_ds.coords: + print(f" {coord}: {raw_ds.coords[coord].shape}") +print(f"\nData Variables: {list(raw_ds.data_vars)}") +print(f"\nShape of each variable:") +for var in raw_ds.data_vars: + print(f" {var}: {raw_ds[var].shape}") +print(f"\nFull xarray structure:") +print(raw_ds) + +print("\n\n" + "=" * 80) +print("PROCESSED DATASET (PyTorch .pt)") +print("=" * 80) + +# Load processed dataset +processed_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/train/data.pt" +processed_data = torch.load(processed_path) + +print(f"\nType: {type(processed_data)}") +if isinstance(processed_data, dict): + print(f"Keys: {list(processed_data.keys())}") + print(f"\nDetailed structure:") + for key, val in processed_data.items(): + if isinstance(val, torch.Tensor): + print(f" {key}:") + print(f" Type: Tensor") + print(f" Shape: {val.shape}") + print(f" Dtype: {val.dtype}") + print(f" Min: {val.min():.4f}, Max: {val.max():.4f}, Mean: {val.mean():.4f}") + elif isinstance(val, list): + print(f" {key}: List of 
{len(val)} items") + if len(val) > 0: + print(f" First item type: {type(val[0])}") + if isinstance(val[0], torch.Tensor): + print(f" First item shape: {val[0].shape}") + elif isinstance(val, np.ndarray): + print(f" {key}:") + print(f" Type: numpy array") + print(f" Shape: {val.shape}") + print(f" Dtype: {val.dtype}") + else: + print(f" {key}: {type(val)}") + if not callable(val): + print(f" Value: {val}") +elif isinstance(processed_data, torch.Tensor): + print(f"Single Tensor:") + print(f" Shape: {processed_data.shape}") + print(f" Dtype: {processed_data.dtype}") + print(f" Min: {processed_data.min():.4f}, Max: {processed_data.max():.4f}") +elif isinstance(processed_data, list): + print(f"List of {len(processed_data)} items") + print(f"First item type: {type(processed_data[0])}") + +print("\n" + "=" * 80) +print("COMPARISON & ANALYSIS") +print("=" * 80) +print("\nKey differences:") +print("1. RAW FORMAT: NetCDF (xarray.Dataset)") +print(" - Hierarchical structure with dimensions, coordinates, and variables") +print(" - Includes full metadata and attributes") +print(" - Human-readable format") +print("\n2. 
PROCESSED FORMAT: PyTorch (.pt)") +print(" - Dictionary structure for efficient ML training") +print(" - Optimized for GPU loading and batching") +print(" - Pre-normalized, pre-processed tensors") + +EOF diff --git a/inspect_filled.sh b/inspect_filled.sh new file mode 100644 index 00000000..ea805d2a --- /dev/null +++ b/inspect_filled.sh @@ -0,0 +1,76 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:15:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name inspect_filled +#SBATCH --output=logs/inspect_filled_%j.out +#SBATCH --error=logs/inspect_filled_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import numpy as np + +print("=" * 100) +print("INSPECTING FILLED MULTI-YEAR DATASET") +print("=" * 100) + +base_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all" + +for split in ["train", "valid", "test"]: + path = f"{base_path}/{split}/data.pt" + data_dict = torch.load(path) + data = data_dict["data"] + + print(f"\n{split.upper()}:") + print(f" Shape: {data.shape}") + print(f" Dtype: {data.dtype}") + print(f" Min: {data.min():.6f}") + print(f" Max: {data.max():.6f}") + print(f" Mean: {data.mean():.6f}") + print(f" Median: {torch.median(data):.6f}") + print(f" Std: {data.std():.6f}") + print(f" NaN count: {torch.isnan(data).sum()}") + + # Check histogram + non_zero = data[data != 0].numel() + zero = (data == 0).sum().item() + total = data.numel() + + print(f" Zero values: {zero:,} ({100*zero/total:.1f}%)") + print(f" Non-zero values: {non_zero:,} ({100*non_zero/total:.1f}%)") + + if non_zero > 0: + non_zero_data = data[data != 0] + print(f" Non-zero range: 
[{non_zero_data.min():.6f}, {non_zero_data.max():.6f}]") + print(f" Non-zero mean: {non_zero_data.mean():.6f}") + +print("\n" + "=" * 100) +print("COMPARISON WITH 2018 DATASET") +print("=" * 100) + +path_2018 = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic/train/data.pt" +data_2018 = torch.load(path_2018)["data"] + +print(f"\n2018 TRAIN:") +print(f" Shape: {data_2018.shape}") +print(f" Min: {data_2018.min():.6f}") +print(f" Max: {data_2018.max():.6f}") +print(f" Mean: {data_2018.mean():.6f}") +print(f" Non-zero: {(data_2018 != 0).sum().item():,} ({100*(data_2018 != 0).sum().item()/data_2018.numel():.1f}%)") + +EOF diff --git a/inspect_osisaf.sh b/inspect_osisaf.sh new file mode 100644 index 00000000..ca57ac7f --- /dev/null +++ b/inspect_osisaf.sh @@ -0,0 +1,39 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 0:10:00 +#SBATCH --nodes 1 +#SBATCH --gpus 0 +#SBATCH --tasks-per-node 1 +#SBATCH --job-name inspect_osisaf + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch + +# Check what's in one of the data files +train_file = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/train/data.pt" +print(f"Loading: {train_file}") +data = torch.load(train_file, map_location='cpu') + +print(f"\nType of loaded data: {type(data)}") +print(f"Data shape: {data.shape if hasattr(data, 'shape') else 'N/A'}") +print(f"Data dtype: {data.dtype if hasattr(data, 'dtype') else 'N/A'}") + +if isinstance(data, dict): + print(f"\nIt's a dict with keys: {data.keys()}") + for k, v in data.items(): + print(f" {k}: {type(v)} shape={v.shape if hasattr(v, 'shape') else 'N/A'}") +else: + print(f"\n❌ It's a raw {type(data).__name__}, not a dict!") + 
print("This is why the dataset code fails.") +EOF diff --git a/inspect_raw_multiyr.sh b/inspect_raw_multiyr.sh new file mode 100644 index 00000000..1a000f85 --- /dev/null +++ b/inspect_raw_multiyr.sh @@ -0,0 +1,60 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:10:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --gpus=0 +#SBATCH --mem=16G +#SBATCH --job-name inspect_raw +#SBATCH --output=logs/inspect_raw_%j.out +#SBATCH --error=logs/inspect_raw_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import xarray as xr + +print("=" * 100) +print("INSPECTING RAW MULTI-YEAR DATASET") +print("=" * 100) + +raw_file = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw/osisaf_nh_sic_reprocessed.nc" + +try: + ds = xr.open_dataset(raw_file) + print(f"\nFile: {raw_file}") + print(f"\nFull structure:") + print(ds) + + print(f"\n\nTime range:") + if 'time' in ds.coords: + times = ds.coords['time'].values + print(f" First time: {times[0]}") + print(f" Last time: {times[-1]}") + print(f" Total timesteps: {len(times)}") + + print(f"\n\nVariables and their stats:") + for var in ds.data_vars: + data = ds[var] + print(f"\n {var}:") + print(f" Shape: {data.shape}") + print(f" Min: {float(data.min()):.6f}") + print(f" Max: {float(data.max()):.6f}") + print(f" Mean: {float(data.mean()):.6f}") + +except Exception as e: + print(f"Error: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + +EOF diff --git a/investigate_structure.sh b/investigate_structure.sh new file mode 100644 index 00000000..f2cf9121 --- /dev/null +++ b/investigate_structure.sh @@ -0,0 +1,101 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:30:00 +#SBATCH 
--nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=32G +#SBATCH --job-name investigate_structure +#SBATCH --output=logs/investigate_structure_%j.out +#SBATCH --error=logs/investigate_structure_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +# Activate conda environment +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import sys + +print("=" * 100) +print("INVESTIGATING DATASET LOADING DIFFERENCES") +print("=" * 100) + +datasets = [ + ("osisaf_nh_sic_all", "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/train/data.pt"), + ("osisaf_nh_sic", "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic/train/data.pt"), +] + +for name, path in datasets: + print(f"\n{'=' * 100}") + print(f"DATASET: {name}") + print(f"{'=' * 100}") + print(f"Path: {path}") + + try: + # Load the data + raw_data = torch.load(path, map_location='cpu') + + print(f"\n1. WHAT IS LOADED FROM torch.load():") + print(f" Type: {type(raw_data)}") + print(f" Type name: {type(raw_data).__name__}") + + if isinstance(raw_data, dict): + print(f" ✓ It's a Dictionary!") + print(f" Keys: {list(raw_data.keys())}") + + # Check if "data" key exists + if "data" in raw_data: + print(f" ✓ 'data' key EXISTS") + print(f" Value type: {type(raw_data['data'])}") + print(f" Value shape: {raw_data['data'].shape}") + else: + print(f" ✗ 'data' key MISSING!") + print(f" Available keys: {list(raw_data.keys())}") + + elif isinstance(raw_data, torch.Tensor): + print(f" ✗ It's a raw Tensor, NOT a dict!") + print(f" Shape: {raw_data.shape}") + print(f" Dtype: {raw_data.dtype}") + print(f" This will FAIL in SpatioTemporalDataset._from_f()") + + else: + print(f" ? Unexpected type: {type(raw_data)}") + + print(f"\n2. 
WHAT _from_f() WOULD DO:") + print(f" The code does: assert 'data' in f, 'HDF5 file must contain data dataset'") + + if isinstance(raw_data, dict) and "data" in raw_data: + print(f" ✓ Would PASS the assertion") + else: + print(f" ✗ Would FAIL the assertion - KeyError or AssertionError") + + except Exception as e: + print(f" ERROR: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + +print("\n\n" + "=" * 100) +print("SUMMARY") +print("=" * 100) +print(""" +ISSUE: One or both datasets may not match what SpatioTemporalDataset._from_f() expects. + +The code expects: + - When loading .pt file: torch.load() returns a dict with key "data" + - dict["data"] contains the actual tensor + +If one dataset is a raw tensor instead of a dict, you need to wrap it. +If one dataset has different keys, the code needs adjustment. +""") + +EOF diff --git a/land_mask_visualization.png b/land_mask_visualization.png new file mode 100644 index 00000000..130e9aaa Binary files /dev/null and b/land_mask_visualization.png differ diff --git a/land_mask_visualization_transposed.png b/land_mask_visualization_transposed.png new file mode 100644 index 00000000..a17f333b Binary files /dev/null and b/land_mask_visualization_transposed.png differ diff --git a/logging.wandb.enabled=false b/logging.wandb.enabled=false new file mode 100644 index 00000000..e69de29b diff --git a/random_slurm_job_scripts/check_nc_dimensions.sh b/random_slurm_job_scripts/check_nc_dimensions.sh new file mode 100644 index 00000000..209cea55 --- /dev/null +++ b/random_slurm_job_scripts/check_nc_dimensions.sh @@ -0,0 +1,115 @@ +#!/bin/bash +#SBATCH --job-name=check_nc_dims +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=8G +#SBATCH --time=00:10:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate 
autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import xarray as xr +import torch +import numpy as np + +print("=" * 80) +print("CHECKING NETCDF DIMENSIONS vs MASK DIMENSIONS") +print("=" * 80) + +# Load netCDF file +nc_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2018.nc" +print(f"\n1. Loading netCDF: {nc_path}") +ds = xr.open_dataset(nc_path) + +print(f"\nDataset variables: {list(ds.data_vars)}") +print(f"Dataset coords: {list(ds.coords)}") +print(f"Dataset dims: {dict(ds.dims)}") + +# Find SIC variable +sic_var = None +for var_name in ['sic', 'ice_conc', 'sea_ice_concentration', 'concentration']: + if var_name in ds.data_vars: + sic_var = var_name + break +if sic_var is None: + sic_var = list(ds.data_vars)[0] + +print(f"\n2. Using SIC variable: '{sic_var}'") +sic = ds[sic_var] +print(f" Shape: {sic.shape}") +print(f" Dims: {sic.dims}") +print(f" Dtype: {sic.dtype}") + +# Get one timestep to see spatial dimensions +print(f"\n3. Extracting first timestep...") +first_timestep = sic.isel(time=0).values +print(f" Shape of first timestep: {first_timestep.shape}") +print(f" First 3x3 corner:") +print(first_timestep[:3, :3]) + +# Load mask +print(f"\n4. Loading land mask...") +mask_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt" +mask = torch.load(mask_path) +print(f" Mask shape: {mask.shape}") +print(f" First 3x3 corner of mask:") +print(mask[:3, :3]) + +# Load processed data (if exists) +print(f"\n5. 
Loading processed PyTorch data...") +try: + data_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic/data.pt" + data_dict = torch.load(data_path) + if isinstance(data_dict, dict) and 'data' in data_dict: + data = data_dict['data'] + else: + data = data_dict + print(f" Data shape: {data.shape}") + print(f" Dims interpretation: (traj, time, spatial_0, spatial_1, channels)") + + # Get first trajectory, first timestep + first_sample = data[0, 0, :, :, 0] + print(f" First sample shape: {first_sample.shape}") + print(f" First 3x3 corner of processed data:") + print(first_sample[:3, :3]) + +except Exception as e: + print(f" Could not load: {e}") + +# Check dimension order in netCDF +print(f"\n6. Analyzing dimension order...") +spatial_dims = [d for d in sic.dims if d != 'time'] +print(f" Spatial dimensions (in order): {spatial_dims}") + +try: + dim0_coords = ds.coords[spatial_dims[0]] + dim1_coords = ds.coords[spatial_dims[1]] + print(f" {spatial_dims[0]} range: {dim0_coords.min().values} to {dim0_coords.max().values}, size={dim0_coords.size}") + print(f" {spatial_dims[1]} range: {dim1_coords.min().values} to {dim1_coords.max().values}, size={dim1_coords.size}") +except: + print(" Could not get coordinate ranges") + +print("\n" + "=" * 80) +print("DIMENSION ANALYSIS SUMMARY") +print("=" * 80) +print(f"NetCDF dimensions: {sic.dims}") +print(f"NetCDF spatial order: {' → '.join(spatial_dims)}") +print(f"Mask shape: {mask.shape}") +print(f"\nPOTENTIAL ISSUE:") +print(f" If netCDF uses (time, yc, xc) or (time, lat, lon)") +print(f" But mask was created with different (row, col) convention") +print(f" Then mask needs transpose to match data spatial layout!") +print("=" * 80) + +PYTHON_EOF + diff --git a/random_slurm_job_scripts/create_subset.sh b/random_slurm_job_scripts/create_subset.sh new file mode 100644 index 00000000..4acfaab6 --- /dev/null +++ b/random_slurm_job_scripts/create_subset.sh @@ -0,0 +1,45 @@ +#!/bin/bash +#SBATCH 
--account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:15:00 +#SBATCH --nodes=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=8G +#SBATCH --job-name=osisaf_subset +#SBATCH --output=logs/subset_%j.out +#SBATCH --error=logs/subset_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' +import torch +import os + +data_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/test/data.pt" +output_dir = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/test" + +print("Loading data...") +data_dict = torch.load(data_path) +data = data_dict['data'] # Extract tensor from dict + +print(f"Full data shape: {data.shape}") +print(f"Data type: {data.dtype}") + +# Create subset: first 100 samples +subset = data[:100] + +subset_path = os.path.join(output_dir, "data_subset.pt") +torch.save(subset, subset_path) + +print(f"Saved subset → {subset_path}") +print(f"Subset shape: {subset.shape}") +print(f"Subset size: {os.path.getsize(subset_path) / 1e6:.2f} MB") +EOF \ No newline at end of file diff --git a/random_slurm_job_scripts/diagnose_transpose.sh b/random_slurm_job_scripts/diagnose_transpose.sh new file mode 100644 index 00000000..58fa6afa --- /dev/null +++ b/random_slurm_job_scripts/diagnose_transpose.sh @@ -0,0 +1,139 @@ +#!/bin/bash +#SBATCH --job-name=diagnose_transpose +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=16G +#SBATCH --time=00:10:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd 
/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import torch +import xarray as xr +import numpy as np +from einops import rearrange + +print("=" * 80) +print("DIAGNOSING TRANSPOSE ISSUE") +print("=" * 80) + +# 1. Load netCDF and get spatial dimensions +print("\n1. Loading netCDF to understand dimension convention...") +nc_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2018.nc" +ds = xr.open_dataset(nc_path) +sic = ds['ice_conc'] +print(f" NetCDF dims: {sic.dims}") +print(f" NetCDF shape: {sic.shape}") + +# Get one valid timestep (not all zeros) +for i in range(min(10, sic.shape[0])): + sample = sic.isel(time=i).values + if not np.all(sample == 0) and not np.all(np.isnan(sample)): + first_valid = sample + print(f" Using timestep {i} for analysis") + break + +print(f" Sample shape: {first_valid.shape}") +print(f" Sample has NaN: {np.any(np.isnan(first_valid))}") +print(f" Sample value range: {np.nanmin(first_valid):.3f} to {np.nanmax(first_valid):.3f}") + +# 2. Load mask +print("\n2. Loading mask...") +mask_path = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt" +mask = torch.load(mask_path).numpy() +print(f" Mask shape: {mask.shape}") +print(f" Mask values: {np.unique(mask)}") + +# 3. Compare corners - if mask matches data, corners should align +print("\n3. Comparing corner patterns...") +print(" Data (yc, xc) top-left 5x5:") +print(first_valid[:5, :5]) +print("\n Mask (?, ?) top-left 5x5:") +print(mask[:5, :5]) +print("\n Mask TRANSPOSED (.T) top-left 5x5:") +print(mask.T[:5, :5]) + +# Find a distinctive pattern in data to match +print("\n4. 
Finding distinctive ocean/land boundary...") +# Look for transition from ocean (valid data) to land (NaN or 0) +# Check middle row +mid_row_idx = mask.shape[0] // 2 +data_mid_row = first_valid[mid_row_idx, :] +mask_mid_row = mask[mid_row_idx, :] +mask_mid_col = mask[:, mid_row_idx] # if transposed + +print(f" Data middle row (yc={mid_row_idx}): valid count = {np.sum(~np.isnan(data_mid_row))}/{len(data_mid_row)}") +print(f" Mask middle row: ocean count = {np.sum(mask_mid_row==1)}/{len(mask_mid_row)}") +print(f" Mask middle col: ocean count = {np.sum(mask_mid_col==1)}/{len(mask_mid_col)}") + +# 5. Check if data NaN pattern matches mask +print("\n5. Checking NaN/mask alignment...") +is_valid_data = ~np.isnan(first_valid) # True where data is valid (ocean) +is_ocean_mask = mask == 1 # True where mask says ocean + +matches_direct = np.sum(is_valid_data == is_ocean_mask) +total_pixels = mask.size +matches_transposed = np.sum(is_valid_data == is_ocean_mask.T) + +print(f" Pixels matching (mask as-is): {matches_direct}/{total_pixels} ({100*matches_direct/total_pixels:.1f}%)") +print(f" Pixels matching (mask transposed): {matches_transposed}/{total_pixels} ({100*matches_transposed/total_pixels:.1f}%)") + +# 6. Check rearrange behavior +print("\n6. Testing einops rearrange...") +mask_torch = torch.from_numpy(mask) +mask_rearranged = rearrange(mask_torch, 'w h -> 1 1 w h 1') +print(f" Original mask shape: {mask_torch.shape}") +print(f" After rearrange('w h -> 1 1 w h 1'): {mask_rearranged.shape}") +print(f" Interpretation: (batch, time, w, h, channels) = {mask_rearranged.shape}") + +# If we have (yc, xc) and want (batch, time, yc, xc, channels), rearrange is correct +# But if dataset.py expects (batch, time, WIDTH, HEIGHT, channels) and +# unpacks as (width, height), we need to know which is which! 
+ +print("\n" + "=" * 80) +print("DIAGNOSIS SUMMARY") +print("=" * 80) + +if matches_transposed > matches_direct: + print("❌ TRANSPOSE MISMATCH CONFIRMED!") + print(f" Mask needs .T to match data spatial layout") + print(f" Match rate improves from {100*matches_direct/total_pixels:.1f}% to {100*matches_transposed/total_pixels:.1f}%") + print("\n🔍 ROOT CAUSE:") + print(" Option A: Mask was saved with (xc, yc) instead of (yc, xc)") + print(" Option B: Data processing swapped dimensions during conversion") + print(" Option C: Dataset.py unpacks (width, height) in wrong order") +else: + print("✓ No transpose issue detected") + print(f" Mask aligns correctly with data ({100*matches_direct/total_pixels:.1f}% match)") + +print("=" * 80) + +# 7. Detailed dimension tracking +print("\n7. DIMENSION TRACKING:") +print(" NetCDF: (time, yc, xc)") +print(" → get_osisaf_data.py: ensures (time, yc, xc) via ensure_order()") +print(" → Adds channel: (time, yc, xc) → (time, yc, xc, 1)") +print(" → Stacks years: (N_years, 365, yc, xc, 1)") +print(" → Dataset unpacks to:", end="") +print(" (n_traj, n_time, self.width, self.height, n_channels)") +print(f" = (n_traj, n_time, {mask.shape[0]}, {mask.shape[1]}, 1)") +print(f" So: width={mask.shape[0]} (yc dimension), height={mask.shape[1]} (xc dimension)") +print(f"\n Mask shape: {mask.shape} - interpreted as (dim0, dim1)") +print(f" rearrange('w h -> ...') treats:") +print(f" w = mask.shape[0] = {mask.shape[0]}") +print(f" h = mask.shape[1] = {mask.shape[1]}") +print(f"\n If data has (yc, xc) and mask has (yc, xc): ✓ MATCH") +print(f" If data has (yc, xc) and mask has (xc, yc): ❌ MISMATCH") + +PYTHON_EOF + diff --git a/random_slurm_job_scripts/download_osisaf_full.sh b/random_slurm_job_scripts/download_osisaf_full.sh new file mode 100644 index 00000000..12d62262 --- /dev/null +++ b/random_slurm_job_scripts/download_osisaf_full.sh @@ -0,0 +1,100 @@ +#!/bin/bash +#SBATCH --job-name=download_osisaf_full +#SBATCH --account=vjgo8416-ai-phy-sys 
+#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --time=12:00:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import xarray as xr +import os +from pathlib import Path + +# Configuration +OPENDAP_URL = "https://thredds.met.no/thredds/dodsC/osisaf/met.no/reprocessed/ice/conc_450a1_nh_agg" +RAW_DIR = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf") + +# Create output directory +RAW_DIR.mkdir(parents=True, exist_ok=True) + +print("=" * 80) +print("DOWNLOADING OSI-SAF FULL DATASET (YEAR BY YEAR)") +print("=" * 80) + +# Open the dataset to see available years +print(f"\n1. Connecting to OPeNDAP: {OPENDAP_URL}") +ds = xr.open_dataset(OPENDAP_URL, engine='netcdf4') + +print(f"\n2. Dataset info:") +print(f" Variables: {list(ds.data_vars)}") +print(f" Coordinates: {list(ds.coords)}") +print(f" Time range: {ds.time.min().values} to {ds.time.max().values}") + +# Get available years +years = sorted(set(ds.time.dt.year.values)) +print(f"\n3. Available years: {years[0]} to {years[-1]} ({len(years)} years)") + +# Download each year separately +print(f"\n4. 
Downloading year by year...") +for year in years: + output_file = RAW_DIR / f"osisaf_nh_{year}.nc" + + # Skip if already exists + if output_file.exists(): + size_mb = output_file.stat().st_size / (1024 * 1024) + print(f" [{year}] Already exists ({size_mb:.1f} MB), skipping...") + continue + + print(f" [{year}] Downloading...") + try: + # Select this year's data + ds_year = ds.sel(time=str(year)) + + # Check if year has data + n_timesteps = len(ds_year.time) + if n_timesteps == 0: + print(f" [{year}] No data available, skipping...") + continue + + print(f" [{year}] Found {n_timesteps} timesteps, saving...") + + # Save to netCDF + ds_year.to_netcdf(output_file, engine='netcdf4') + + size_mb = output_file.stat().st_size / (1024 * 1024) + print(f" [{year}] ✓ Saved ({size_mb:.1f} MB)") + + except Exception as e: + print(f" [{year}] ✗ Error: {e}") + # Remove partial file if it exists + if output_file.exists(): + output_file.unlink() + continue + +print("\n" + "=" * 80) +print("DOWNLOAD COMPLETE") +print("=" * 80) +print(f"Raw data saved to: {RAW_DIR}") + +# List downloaded files +print("\nDownloaded files:") +for f in sorted(RAW_DIR.glob("osisaf_nh_*.nc")): + size_mb = f.stat().st_size / (1024 * 1024) + print(f" {f.name}: {size_mb:.1f} MB") + +print("=" * 80) + +PYTHON_EOF + diff --git a/random_slurm_job_scripts/plot_mask.sh b/random_slurm_job_scripts/plot_mask.sh new file mode 100644 index 00000000..b7c85c96 --- /dev/null +++ b/random_slurm_job_scripts/plot_mask.sh @@ -0,0 +1,57 @@ +#!/bin/bash +#SBATCH --job-name=plot_mask +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=8G +#SBATCH --time=00:05:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import 
torch +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import numpy as np + +# Load mask +mask_path = '/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt' +mask = torch.load(mask_path) + +print(f"Mask loaded: shape={mask.shape}, dtype={mask.dtype}") + +# Create figure with colorbar +fig, ax = plt.subplots(figsize=(10, 10), dpi=100) + +# Plot binary mask +im = ax.imshow(mask.numpy(), cmap='binary', origin='upper') + +# Add colorbar +cbar = plt.colorbar(im, ax=ax) +cbar.set_label('Mask Value (0=Land, 1=Ocean)', rotation=270, labelpad=20) + +ax.set_title('Sea Ice Land Mask (NH, 432x432)', fontsize=14, fontweight='bold') +ax.set_xlabel('Longitude Index') +ax.set_ylabel('Latitude Index') + +# Save figure +output_path = '/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/land_mask_visualization.png' +plt.savefig(output_path, dpi=100, bbox_inches='tight') +print(f"\nMask visualization saved to: {output_path}") + +# Print statistics +print(f"\nMask Statistics:") +print(f" Land pixels (0): {torch.sum(mask == 0).item():,} ({100*torch.sum(mask==0).item()/mask.numel():.1f}%)") +print(f" Ocean pixels (1): {torch.sum(mask == 1).item():,} ({100*torch.sum(mask==1).item()/mask.numel():.1f}%)") + +PYTHON_EOF + diff --git a/random_slurm_job_scripts/plot_mask_transpose.sh b/random_slurm_job_scripts/plot_mask_transpose.sh new file mode 100644 index 00000000..2fdd12e9 --- /dev/null +++ b/random_slurm_job_scripts/plot_mask_transpose.sh @@ -0,0 +1,55 @@ +#!/bin/bash +#SBATCH --job-name=plot_mask_transpose +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=8G +#SBATCH --time=00:05:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd 
/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import torch +import matplotlib.pyplot as plt + +# Load mask +mask_path = '/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt' +mask = torch.load(mask_path) + +print(f"Original mask shape: {mask.shape}") + +# Transpose the mask +mask_transposed = mask.T + +print(f"Transposed mask shape: {mask_transposed.shape}") + +# Create figure +fig, ax = plt.subplots(figsize=(10, 10), dpi=100) + +# Plot transposed binary mask +im = ax.imshow(mask_transposed.numpy(), cmap='binary', origin='upper') + +# Add colorbar +cbar = plt.colorbar(im, ax=ax) +cbar.set_label('Mask Value (0=Land, 1=Ocean)', rotation=270, labelpad=20) + +ax.set_title('Sea Ice Land Mask - TRANSPOSED (NH, 432x432)', fontsize=14, fontweight='bold') +ax.set_xlabel('Longitude Index') +ax.set_ylabel('Latitude Index') + +# Save figure +output_path = '/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/land_mask_visualization_transposed.png' +plt.savefig(output_path, dpi=100, bbox_inches='tight') +print(f"Transposed mask visualization saved to: {output_path}") + +PYTHON_EOF + diff --git a/random_slurm_job_scripts/transpose_mask.sh b/random_slurm_job_scripts/transpose_mask.sh new file mode 100644 index 00000000..17f50855 --- /dev/null +++ b/random_slurm_job_scripts/transpose_mask.sh @@ -0,0 +1,80 @@ +#!/bin/bash +#SBATCH --job-name=transpose_mask +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos=turing +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=4G +#SBATCH --time=00:05:00 + +module purge +module load baskerville +module load bask-apps/live + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python << 'PYTHON_EOF' +import torch +import os +from pathlib import Path + +mask_path = 
"/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt" + +print("=" * 80) +print("TRANSPOSING LAND MASK") +print("=" * 80) + +# Load original mask +print(f"\n1. Loading original mask from: {mask_path}") +mask_original = torch.load(mask_path) +print(f" Original shape: {mask_original.shape}") +print(f" Original dtype: {mask_original.dtype}") + +# Create backup +backup_path = mask_path.replace(".pt", "_original_backup.pt") +if not os.path.exists(backup_path): + print(f"\n2. Creating backup at: {backup_path}") + torch.save(mask_original, backup_path) + print(" ✓ Backup created") +else: + print(f"\n2. Backup already exists at: {backup_path}") + +# Transpose the mask +print(f"\n3. Transposing mask...") +mask_transposed = mask_original.T +print(f" Transposed shape: {mask_transposed.shape}") + +# Verify stats are preserved +land_pixels_orig = (mask_original == 0).sum().item() +ocean_pixels_orig = (mask_original == 1).sum().item() +land_pixels_trans = (mask_transposed == 0).sum().item() +ocean_pixels_trans = (mask_transposed == 1).sum().item() + +print(f"\n4. Verifying statistics...") +print(f" Original - Land: {land_pixels_orig:,}, Ocean: {ocean_pixels_orig:,}") +print(f" Transposed - Land: {land_pixels_trans:,}, Ocean: {ocean_pixels_trans:,}") + +if land_pixels_orig == land_pixels_trans and ocean_pixels_orig == ocean_pixels_trans: + print(" ✓ Statistics match (transpose is correct)") +else: + print(" ✗ WARNING: Statistics don't match!") + +# Save transposed mask with original name +print(f"\n5. 
Saving transposed mask to: {mask_path}") +torch.save(mask_transposed, mask_path) +print(" ✓ Transposed mask saved") + +print("\n" + "=" * 80) +print("MASK TRANSPOSITION COMPLETE") +print("=" * 80) +print(f"Original mask backed up to: {backup_path}") +print(f"Transposed mask saved to: {mask_path}") +print("\nThe model will now use the correctly oriented mask!") +print("=" * 80) + +PYTHON_EOF + diff --git a/resolved_eval_config.yaml b/resolved_eval_config.yaml new file mode 100644 index 00000000..721a6cf9 --- /dev/null +++ b/resolved_eval_config.yaml @@ -0,0 +1,118 @@ +datamodule: + _target_: autocast.data.datamodule.SpatioTemporalDataModule + batch_size: 1 + data_path: /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf_selectedyears + n_steps_input: 5 + n_steps_output: 1 + num_workers: 4 + stride: 1 + use_normalization: false + verbose: false +eval: + batch_indices: + - 0 + checkpoint: /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29/encoder_processor_decoder.ckpt + csv_path: null + device: cuda + fps: 5 + free_running_only: true + max_rollout_steps: 25 + metrics: + - mse + - rmse + - vrmse + video_dir: outputs/seaice/epd_flow_pixels_in2_out1_masked__selectedyears/2026-02-13_15-12-29/eval_videos + video_format: mp4 + video_sample_index: 0 +experiment_name: seaice/epd_flow_pixels_in2_out1_masked__selectedyears +logging: + wandb: + config: {} + enabled: true + entity: null + group: null + id: null + job_type: seaice/epd_flow_pixels_in2_out1_masked__selectedyears + log_model: false + mode: online + name: null + notes: null + project: autocast + resume: null + save_dir: null + settings: {} + tags: [] + watch: + log: null + log_freq: 100 +model: + decoder: + _target_: autocast.decoders.identity.IdentityDecoder + encoder: + _target_: autocast.encoders.identity.IdentityEncoder + in_channels: 1 + loss_func: + _target_: torch.nn.MSELoss + processor: + 
_target_: autocast.processors.masked_flow_matching.MaskedFlowMatchingProcessor + backbone: + _target_: autocast.nn.unet.TemporalUNetBackbone + cond_channels: 1 + global_cond_channels: null + hid_blocks: + - 2 + - 2 + - 2 + hid_channels: + - 32 + - 64 + - 128 + in_channels: 1 + include_global_cond: false + mod_features: 256 + n_steps_input: 5 + n_steps_output: 1 + out_channels: 1 + periodic: false + spatial: 2 + temporal_method: none + flow_ode_steps: 4 + mask_path: /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/land_mask.pt + n_channels_out: 1 + n_steps_output: 1 + train_in_latent_space: true +optimizer: + betas: + - 0.9 + - 0.99 + grad_clip: 1 + learning_rate: 0.0001 + name: adamw_0.0001_cosine + optimizer: adamw + scheduler: cosine + warmup: 0 + weight_decay: 0.0 +output: + checkpoint_name: encoder_processor_decoder.ckpt + checkpoint_path: null + save_config: true + skip_test: false +seed: 42 +trainer: + _target_: lightning.pytorch.trainer.trainer.Trainer + accelerator: gpu + callbacks: + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + every_n_train_steps: 5000 + filename: step-{step} + mode: min + monitor: train_loss + save_on_train_epoch_end: false + save_top_k: 1 + default_root_dir: null + detect_anomaly: false + devices: 1 + enable_checkpointing: true + gradient_clip_val: null + log_every_n_steps: 10 + max_epochs: 40 diff --git a/restructure_osisaf.sh b/restructure_osisaf.sh new file mode 100644 index 00000000..502f71bd --- /dev/null +++ b/restructure_osisaf.sh @@ -0,0 +1,52 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 0:30:00 +#SBATCH --nodes 1 +#SBATCH --gpus 0 +#SBATCH --tasks-per-node 1 +#SBATCH --job-name restructure_osisaf + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +python << 'EOF' 
+import torch +from pathlib import Path + +base_path = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all") + +print("Loading train data...") +train_data = torch.load(base_path / "train" / "data.pt") +print(f"Train shape: {train_data.shape}") + +print("Loading valid data...") +valid_data = torch.load(base_path / "valid" / "data.pt") +print(f"Valid shape: {valid_data.shape}") + +print("Loading test data...") +test_data = torch.load(base_path / "test" / "data.pt") +print(f"Test shape: {test_data.shape}") + +# Create the expected dict structure +data_dict = { + "data": { + "train": train_data, + "valid": valid_data, + "test": test_data, + } +} + +# Save as single file +output_path = base_path / "osisaf_nh_sic_all.pt" +print(f"\nSaving restructured data to: {output_path}") +torch.save(data_dict, output_path) +print("✓ Done!") +EOF diff --git a/run_climatology.sh b/run_climatology.sh new file mode 100644 index 00000000..cc40ece0 --- /dev/null +++ b/run_climatology.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 00:20:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus=0 +#SBATCH --mem=64G +#SBATCH --job-name climatology +#SBATCH --output=logs/climatology_%j.out +#SBATCH --error=logs/climatology_%j.err + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.11.3-GCCcore-12.3.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python build_climatology_baseline.py diff --git a/run_download_opendap.sh b/run_download_opendap.sh new file mode 100644 index 00000000..0a85cc1b --- /dev/null +++ b/run_download_opendap.sh @@ -0,0 +1,26 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 03:00:00 +#SBATCH --nodes=1 +#SBATCH 
#!/bin/bash
#SBATCH --job-name=create_mask
#SBATCH --account=vjgo8416-ai-phy-sys
#SBATCH --qos=turing
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=8G
#SBATCH --time=00:10:00

# Build a static land/ocean mask from one year of raw OSI SAF sea ice
# concentration data and save it transposed, as a torch tensor.

# Abort on the first failing command (module load, conda activate, cd) so the
# Python step never runs in a half-configured environment. This matches the
# other sbatch scripts in this repository.
set -e

module purge
module load baskerville
module load bask-apps/live

source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh
conda activate autocast

cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast

python << 'PYTHON_EOF'
from pathlib import Path

import numpy as np
import torch
import xarray as xr

# Variable names tried, in order, when looking for sea ice concentration.
SIC_CANDIDATES = ("sic", "sea_ice_concentration", "ice_conc", "concentration")
# A pixel counts as ocean when at least this fraction of timesteps is non-NaN.
COVERAGE_THRESHOLD = 0.5

print("=" * 80)
print("CREATING AND TRANSPOSING SEA ICE LAND MASK")
print("=" * 80)

# Step 1: Create mask from raw netCDF
nc_file = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2018.nc"
output_dir = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf/osisaf_nh_sic_all")
mask_path = output_dir / "land_mask.pt"

print(f"\n1. Creating mask from: {nc_file}")
print("Loading netCDF file...")
# The context manager closes the file even if variable lookup or loading fails.
with xr.open_dataset(nc_file) as ds:
    print(f"Dataset variables: {list(ds.data_vars)}")

    sic_var = next((name for name in SIC_CANDIDATES if name in ds.data_vars), None)
    if sic_var is None:
        # Best-effort fallback kept from the original script, but made visible:
        # the first variable is not guaranteed to be sea ice concentration.
        sic_var = list(ds.data_vars)[0]
        print("WARNING: no known SIC variable name found; falling back to the first variable")

    print(f"Using variable: {sic_var}")
    sic_data = ds[sic_var].values

print(f"SIC data shape: {sic_data.shape}")

# Create mask: 1 for ocean (valid data), 0 for land (invalid/NaN)
valid = ~np.isnan(sic_data)
if sic_data.ndim == 3:
    # Time series data: reduce over the leading (time) axis.
    valid_per_pixel = np.sum(valid, axis=0)
    total_timesteps = sic_data.shape[0]
    coverage = valid_per_pixel / total_timesteps
    mask = (coverage >= COVERAGE_THRESHOLD).astype(np.float32)
    print("Created mask from time series (50% coverage threshold)")
else:
    mask = valid.astype(np.float32)
    print("Created mask from static data")

mask_tensor = torch.from_numpy(mask).float()
print(f"Mask shape: {mask_tensor.shape}")
print(f"Land pixels (0): {(mask_tensor == 0).sum().item():,}")
print(f"Ocean pixels (1): {(mask_tensor == 1).sum().item():,}")

# The masked processor only accepts a 2D mask, and `.T` is only well defined
# for 2D tensors, so fail here rather than at training time.
if mask_tensor.ndim != 2:
    raise SystemExit(f"Expected a 2D mask, got shape {tuple(mask_tensor.shape)}")

# Step 2: Transpose the mask
# NOTE: a transpose cannot change the land/ocean counts printed above, so the
# former count comparison was removed; it could never fail.
print("\n2. Transposing mask...")
mask_transposed = mask_tensor.T.contiguous()
print(f"Transposed shape: {mask_transposed.shape}")

# Step 3: Save
output_dir.mkdir(parents=True, exist_ok=True)
torch.save(mask_transposed, mask_path)
print(f"\n3. Saved transposed mask to: {mask_path}")

print("\n" + "=" * 80)
print("MASK CREATION COMPLETE")
print("=" * 80)
print(f"Path: {mask_path}")
print(f"Shape: {mask_transposed.shape}")
print("Ready to use in training!")
print("=" * 80)

PYTHON_EOF
Backing up existing mask to: {backup_path}") + import shutil + shutil.copy(mask_path, backup_path) + +print(f"\n2. Creating new mask from: {nc_file}") +print(f"Loading netCDF file...") +ds = xr.open_dataset(nc_file) +print(f"Dataset variables: {list(ds.data_vars)}") + +# Try to find the sea ice concentration variable +sic_var = None +for var_name in ['sic', 'ice_conc', 'sea_ice_concentration', 'concentration']: + if var_name in ds.data_vars: + sic_var = var_name + break + +if sic_var is None: + sic_var = list(ds.data_vars)[0] + +print(f"Using variable: {sic_var}") +sic_data = ds[sic_var].values +print(f"SIC data shape: {sic_data.shape}") + +# Create mask: 1 for ocean (valid data), 0 for land (invalid/NaN) +if sic_data.ndim == 3: + # Time series data + valid_per_pixel = np.sum(~np.isnan(sic_data), axis=0) + total_timesteps = sic_data.shape[0] + coverage = valid_per_pixel / total_timesteps + mask = (coverage >= 0.5).astype(np.float32) + print(f"Created mask from time series (50% coverage threshold)") +else: + mask = (~np.isnan(sic_data)).astype(np.float32) + print(f"Created mask from static data") + +mask_tensor = torch.from_numpy(mask).float() +print(f"Mask shape: {mask_tensor.shape}") +print(f"Land pixels (0): {(mask_tensor == 0).sum().item():,}") +print(f"Ocean pixels (1): {(mask_tensor == 1).sum().item():,}") + +ds.close() + +# Step 2: Save mask (no transpose) +print(f"\n3. Saving mask (original orientation)...") +torch.save(mask_tensor, mask_path) +print(f"Saved mask to: {mask_path}") + +print("\n" + "=" * 80) +print("MASK CREATION COMPLETE (FROM 2020 DATA - NO TRANSPOSE)") +print("=" * 80) +print(f"Path: {mask_path}") +print(f"Shape: {mask_tensor.shape}") +print(f"Ready to use in training!") +print("=" * 80) + +PYTHON_EOF diff --git a/scripts/create_seaice_mask.py b/scripts/create_seaice_mask.py new file mode 100644 index 00000000..74671293 --- /dev/null +++ b/scripts/create_seaice_mask.py @@ -0,0 +1,124 @@ +"""Create a land mask from raw sea ice netCDF data. 
+ +For sea ice data, land regions should be masked out (0), and ocean/valid regions kept (1). +This script reads from the raw OSISAF netCDF file to generate a static land mask. +""" + +import torch +import numpy as np +import xarray as xr +from pathlib import Path + + +def create_seaice_mask_from_nc( + nc_path: str = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2018.nc", + method: str = "data_coverage", + coverage_threshold: float = 0.5, + output_path: str = None, +) -> np.ndarray: + """Create a binary land mask from raw netCDF sea ice data. + + Args: + nc_path: Path to raw netCDF file + method: How to identify land regions: + - "data_coverage": Land = regions with = coverage_threshold).astype(np.float32) + print(f"Coverage threshold: {coverage_threshold}") + print(f"Valid timesteps per pixel: min={valid_per_pixel.min()}, max={valid_per_pixel.max()}, mean={valid_per_pixel.mean():.1f}") + elif method == "nan_based": + # Land = all timesteps are NaN + mask = (~np.all(np.isnan(sic_data), axis=0)).astype(np.float32) + else: + raise ValueError(f"Unknown method for 3D data: {method}") + + elif sic_data.ndim == 2: + print("Processing static data (lat, lon)...") + if method == "data_coverage": + # For static data, look for NaN as missing/land + mask = (~np.isnan(sic_data)).astype(np.float32) + elif method == "nan_based": + mask = (~np.isnan(sic_data)).astype(np.float32) + else: + raise ValueError(f"Unknown method for 2D data: {method}") + else: + raise ValueError(f"Unexpected data shape: {sic_data.shape}") + + # Convert to torch and ensure proper format + mask_tensor = torch.from_numpy(mask).float() + + print(f"\nMask shape: {mask_tensor.shape}") + print(f"Land pixels (0): {(mask_tensor == 0).sum().item()}") + print(f"Ocean pixels (1): {(mask_tensor == 1).sum().item()}") + print(f"Land coverage: {(mask_tensor == 0).float().mean().item() * 100:.1f}%") + print(f"Ocean coverage: {(mask_tensor == 1).float().mean().item() * 100:.1f}%") + + # 
"""Process full OSI-SAF dataset into train/valid/test splits.

Reads individual year netCDF files and creates processed dataset matching
the structure of the 2018-only dataset.
"""

import os
import numpy as np
import xarray as xr
import torch
from pathlib import Path


def process_year(nc_path: Path) -> "np.ndarray | None":
    """Process one year of OSI-SAF data.

    Feb 29 is dropped so that every year contributes the same number of days,
    NaNs are replaced by 0 and a trailing channel axis is appended.

    Args:
        nc_path: Path to year's netCDF file. It must contain an ``ice_conc``
            variable with a ``time`` coordinate.

    Returns:
        float32 array of shape (365, y, x, 1), or None when the file does not
        hold exactly 365 days after dropping Feb 29.
    """
    print(f" Loading {nc_path.name}...")

    # The context manager guarantees the file is closed on every path,
    # including the early return for incomplete years.
    with xr.open_dataset(nc_path) as ds:
        # Get ice concentration variable
        sic = ds["ice_conc"]

        # Drop Feb 29 to ensure 365 days
        time_index = sic["time"].to_index()
        keep = ~((time_index.month == 2) & (time_index.day == 29))
        sic = sic.isel(time=np.flatnonzero(keep))

        # Check we have 365 days
        n_days = sic.sizes["time"]
        if n_days != 365:
            print(f" WARNING: Expected 365 days, got {n_days}")
            return None

        # Get values as numpy array: (time, y, x). Loaded before the file closes.
        data = sic.values

    # Convert NaN to 0, then store as float32. The dataset is saved as float32
    # anyway, so casting per year only lowers peak memory.
    data = np.nan_to_num(data, nan=0.0).astype(np.float32, copy=False)

    # Add channel dimension: (time, y, x) -> (time, y, x, 1)
    data = data[:, :, :, None]

    print(f" Shape: {data.shape}, dtype: {data.dtype}")
    return data


def main():
    """Build the train/valid/test tensors and write one ``data.pt`` per split.

    Years whose file is missing, unreadable or incomplete are skipped
    (best-effort) and reported. A split with no usable year is not written.
    """
    # Paths
    raw_dir = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf")
    output_dir = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf_full")

    # Year splits
    train_years = list(range(1979, 2011))  # 1979-2010 (32 years)
    valid_years = list(range(2011, 2016))  # 2011-2015 (5 years)
    test_years = list(range(2016, 2021))  # 2016-2020 (5 years)

    print("=" * 80)
    print("PROCESSING FULL OSI-SAF DATASET")
    print("=" * 80)
    print(f"\nTrain years: {train_years[0]}-{train_years[-1]} ({len(train_years)} years)")
    print(f"Valid years: {valid_years[0]}-{valid_years[-1]} ({len(valid_years)} years)")
    print(f"Test years: {test_years[0]}-{test_years[-1]} ({len(test_years)} years)")
    print(f"\nInput directory: {raw_dir}")
    print(f"Output directory: {output_dir}")

    splits = {
        "train": train_years,
        "valid": valid_years,
        "test": test_years,
    }

    # Create output directories
    for split_name in splits:
        (output_dir / split_name).mkdir(parents=True, exist_ok=True)

    # Shape and size of every split written in this run, for the final summary.
    # Recording them here avoids reloading multi-GB files just to read a shape.
    written = {}

    for split_name, years in splits.items():
        print(f"\n{'=' * 80}")
        print(f"PROCESSING {split_name.upper()} SPLIT ({len(years)} years)")
        print("=" * 80)

        year_data_list = []
        skipped_years = []

        for year in years:
            nc_path = raw_dir / f"osisaf_nh_{year}.nc"

            if not nc_path.exists():
                print(f" WARNING: File not found: {nc_path.name}, skipping")
                skipped_years.append(year)
                continue

            try:
                data = process_year(nc_path)
            except Exception as e:
                # Deliberately best-effort: one unreadable year must not abort
                # a multi-hour job. The year is reported below.
                print(f" ERROR processing {year}: {e}")
                skipped_years.append(year)
                continue

            if data is None:
                skipped_years.append(year)
                continue
            year_data_list.append(data)

        if skipped_years:
            print(f" WARNING: skipped {len(skipped_years)} of {len(years)} years: {skipped_years}")

        if not year_data_list:
            print(f" ERROR: No data processed for {split_name} split!")
            continue

        # Stack all years: (n_years, 365, y, x, 1)
        print(f"\n Stacking {len(year_data_list)} years...")
        stacked_data = np.stack(year_data_list, axis=0)
        # Drop the per-year arrays before saving to keep peak memory down.
        del year_data_list
        print(f" Final shape: {stacked_data.shape}")
        print(f" Data range: [{stacked_data.min():.3f}, {stacked_data.max():.3f}]")

        # Convert to torch tensor
        tensor_data = torch.from_numpy(stacked_data).float()

        # Wrap in dictionary to match 2018 dataset structure
        data_dict = {"data": tensor_data}

        # Save
        output_path = output_dir / split_name / "data.pt"
        print(f" Saving to: {output_path}")
        torch.save(data_dict, output_path)

        file_size = output_path.stat().st_size / (1024**3)  # GB
        print(f" ✓ Saved: {file_size:.2f} GB")
        written[split_name] = (tensor_data.shape, file_size)

    print("\n" + "=" * 80)
    print("PROCESSING COMPLETE")
    print("=" * 80)

    # Summary of the splits written by this run
    for split_name, (shape, size) in written.items():
        print(f"\n{split_name.upper():>6}: {shape} - {size:.2f} GB")


if __name__ == "__main__":
    main()
#!/bin/bash
#SBATCH --job-name=regen_mask_2020
#SBATCH --account=vjgo8416-ai-phy-sys
#SBATCH --qos=turing
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --mem=8G
#SBATCH --time=00:10:00

# Rebuild the land/ocean mask from the 2020 OSI SAF file, keeping a backup of
# any existing mask, and save the new mask transposed.

# Abort on the first failing command so the Python step never runs in a
# half-configured environment (same convention as the other sbatch scripts).
set -e

module purge
module load baskerville
module load bask-apps/live

source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh
conda activate autocast

cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast

python << 'PYTHON_EOF'
import shutil
from pathlib import Path

import numpy as np
import torch
import xarray as xr

# Variable names tried, in order, when looking for sea ice concentration.
SIC_CANDIDATES = ("sic", "sea_ice_concentration", "ice_conc", "concentration")
# A pixel counts as ocean when at least this fraction of timesteps is non-NaN.
COVERAGE_THRESHOLD = 0.5

print("=" * 80)
print("REGENERATING LAND MASK FROM 2020 DATA")
print("=" * 80)

# Use 2020 data instead of 2018
nc_file = "/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/osisaf_nh_2020.nc"
mask_path = Path("/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/raw_osisaf/land_mask.pt")

# Backup old mask if exists
if mask_path.exists():
    backup_path = mask_path.with_suffix('.pt.old')
    print(f"\n1. Backing up old mask to: {backup_path}")
    # Copy the file itself instead of loading and re-saving the tensor: the
    # backup is then byte-identical and does not depend on torch.load.
    shutil.copy2(mask_path, backup_path)
    print(" ✓ Backup created")

print(f"\n2. Creating mask from: {nc_file}")
print(" Loading netCDF file...")
# The context manager closes the file even if variable lookup or loading fails.
with xr.open_dataset(nc_file) as ds:
    print(f" Dataset variables: {list(ds.data_vars)}")

    # Get sea ice concentration variable
    sic_var = next((name for name in SIC_CANDIDATES if name in ds.data_vars), None)
    if sic_var is None:
        # Best-effort fallback kept from the original script, but made visible:
        # the first variable is not guaranteed to be sea ice concentration.
        sic_var = list(ds.data_vars)[0]
        print(" WARNING: no known SIC variable name found; falling back to the first variable")

    print(f" Using variable: {sic_var}")
    sic_data = ds[sic_var].values

print(f" SIC data shape: {sic_data.shape}")

# Create mask: 1 for ocean (valid data), 0 for land (invalid/NaN)
valid = ~np.isnan(sic_data)
if sic_data.ndim == 3:
    # Time series data: reduce over the leading (time) axis.
    valid_per_pixel = np.sum(valid, axis=0)
    total_timesteps = sic_data.shape[0]
    coverage = valid_per_pixel / total_timesteps
    mask = (coverage >= COVERAGE_THRESHOLD).astype(np.float32)
    print(" Created mask from time series (50% coverage threshold)")
else:
    mask = valid.astype(np.float32)
    print(" Created mask from static data")

mask_tensor = torch.from_numpy(mask).float()
print(f"\n3. Original mask shape: {mask_tensor.shape}")
print(f" Land pixels (0): {(mask_tensor == 0).sum().item():,}")
print(f" Ocean pixels (1): {(mask_tensor == 1).sum().item():,}")

# The masked processor only accepts a 2D mask, and `.T` is only well defined
# for 2D tensors, so fail here rather than at training time.
if mask_tensor.ndim != 2:
    raise SystemExit(f"Expected a 2D mask, got shape {tuple(mask_tensor.shape)}")

# Transpose the mask
# NOTE: a transpose cannot change the land/ocean counts printed above, so the
# former count comparison was removed; it could never fail.
print("\n4. Transposing mask...")
mask_transposed = mask_tensor.T.contiguous()
print(f" Transposed shape: {mask_transposed.shape}")

# Save
torch.save(mask_transposed, mask_path)
print(f"\n5. Saved transposed mask to: {mask_path}")

print("\n" + "=" * 80)
print("MASK REGENERATION COMPLETE")
print("=" * 80)
print("Source: osisaf_nh_2020.nc")
print(f"Path: {mask_path}")
print(f"Shape: {mask_transposed.shape}")
print("Ready to use in training!")
print("=" * 80)

PYTHON_EOF
Change to script directory +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +# Create logs directory if it doesn't exist +mkdir -p logs + +# Run the mask creation script +echo "Starting sea ice land mask creation..." +python scripts/create_seaice_mask.py + +echo "Mask creation completed!" diff --git a/src/autocast/configs/datamodule/osisaf_nh_sic.yaml b/src/autocast/configs/datamodule/osisaf_nh_sic.yaml new file mode 100644 index 00000000..349cdeec --- /dev/null +++ b/src/autocast/configs/datamodule/osisaf_nh_sic.yaml @@ -0,0 +1,9 @@ +_target_: autocast.data.datamodule.SpatioTemporalDataModule +data_path: /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_osisaf/osisaf_nh_sic_all +batch_size: 1 +n_steps_input: 5 +n_steps_output: 1 +stride: 1 +verbose: false +use_normalization: false +num_workers: 4 diff --git a/src/autocast/configs/processor/masked_flow_matching.yaml b/src/autocast/configs/processor/masked_flow_matching.yaml new file mode 100644 index 00000000..55e8520f --- /dev/null +++ b/src/autocast/configs/processor/masked_flow_matching.yaml @@ -0,0 +1,9 @@ +defaults: + - /backbone@backbone: unet + - _self_ + +_target_: autocast.processors.masked_flow_matching.MaskedFlowMatchingProcessor +flow_ode_steps: 4 +n_steps_output: null +n_channels_out: null +mask_path: null # User provides this via CLI or config \ No newline at end of file diff --git a/src/autocast/processors/masked_flow_matching.py b/src/autocast/processors/masked_flow_matching.py new file mode 100644 index 00000000..b274b8d5 --- /dev/null +++ b/src/autocast/processors/masked_flow_matching.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import torch +from einops import rearrange +from torch import nn + +from autocast.processors.base import Processor +from autocast.types import EncodedBatch, Tensor + + +class MaskedFlowMatchingProcessor(Processor): + """Processor that wraps a flow-matching generative model.""" + + def __init__( + self, + *, + backbone: 
nn.Module, + flow_ode_steps: int = 1, + n_steps_output: int = 4, + n_channels_out: int = 1, + mask: Tensor | None = None, + ) -> None: + # Store core hyperparameters and optional prebuilt backbone. + super().__init__() + self.flow_matching_model = backbone + self.flow_ode_steps = max(flow_ode_steps, 1) + self.n_steps_output = n_steps_output + self.n_channels_out = n_channels_out + # Assume 2D mask for now. Store as a buffer so it's moved with the module. + if mask is not None: + if mask.ndim != 2: + raise ValueError(f"Mask must be 2D tensor, got shape {mask.shape}") + mask = rearrange(mask, "w h -> 1 1 w h 1") + # register as buffer so `.to(device)` moves it + self.register_buffer("mask", mask) + else: + self.mask = None + + def flow_field( + self, z: Tensor, t: Tensor, x: Tensor, global_cond: Tensor | None = None + ) -> Tensor: + """Flow matching vector field. + + The vector field over the tangent space of output states (z). + conditioned on input states (x) at time (t). + + Args: + z: Current output states of shape (B, T_out, *spatial, C_out). + t: Time tensor of shape (B,). + x: Conditioning inputs of shape (B, T_in, *spatial, C_in). + global_cond: Optional non-spatial conditioning/modulation tensor. + + Returns + ------- + Time derivative of output states with the same shape as `z`. + """ + return self.flow_matching_model(z, t=t, cond=x, global_cond=global_cond) + + def forward(self, x: Tensor, global_cond: Tensor | None) -> Tensor: + """Alias to map for Lightning/PyTorch compatibility.""" + return self.map(x, global_cond) + + def _apply_mask(self, z: Tensor) -> Tensor: + """Apply mask to tensor if mask is set.""" + if self.mask is None: + return z + + return z * self.mask + + def map(self, x: Tensor, global_cond: Tensor | None) -> Tensor: + """Map inputs states (x) to output states (z) by integrating the flow ODE. + + Starting from noise, Euler-integrate the learned vector field until t=1. + + Args: + x: Conditioning inputs of shape (B, T_in, *spatial, C_in). 
+ + Returns + ------- + Generated outputs of shape (B, T_out, *spatial, C_out). + """ + batch_size = x.shape[0] + device, dtype = x.device, x.dtype + + # Initialize noisy sample and scalar time for each batch element. + spatial_shape = tuple(x.shape[2:-1]) + z_shape = (batch_size, self.n_steps_output, *spatial_shape, self.n_channels_out) + z = torch.randn(z_shape, device=device, dtype=dtype) + t = torch.zeros(batch_size, device=device, dtype=dtype) + + # Simple fixed-step Euler integration over the flow field. + dt = torch.tensor(1.0 / self.flow_ode_steps, device=device, dtype=dtype) + + # Apply mask to inputs and noise if mask is set + x = self._apply_mask(x) + z = self._apply_mask(z) + for _ in range(self.flow_ode_steps): + z = z + dt * self.flow_field(z, t, x, global_cond) + t = t + dt + # Apply mask to updated state of z if mask is set + z = self._apply_mask(z) + return z + + def loss(self, batch: EncodedBatch) -> Tensor: + """Compute flow-matching loss for a batch.""" + input_states = batch.encoded_inputs + target_states = batch.encoded_output_fields + + if ( + target_states.shape[1] != self.n_steps_output + or target_states.shape[-1] != self.n_channels_out + ): + msg = ( + "Target shape does not match configured output dimensions " + f"(expected T_out={self.n_steps_output}, C_out={self.n_channels_out}, " + f"got T_out={target_states.shape[1]}, C_out={target_states.shape[-1]})." 
+ ) + raise ValueError(msg) + + batch_size = target_states.shape[0] + + z0 = torch.randn_like(target_states) + + # Apply mask to initial noise, inputs, and target states if mask is set + z0 = self._apply_mask(z0) + input_states = self._apply_mask(input_states) + target_states = self._apply_mask(target_states) + + t = torch.rand( + batch_size, device=target_states.device, dtype=target_states.dtype + ) + t_broadcast = t.view(batch_size, *([1] * (target_states.ndim - 1))) + zt = (1 - t_broadcast) * z0 + t_broadcast * target_states + + target_velocity = target_states - z0 + v_pred = self.flow_field(zt, t, input_states, global_cond=batch.global_cond) + + squared_diff = (v_pred - target_velocity) ** 2 + if self.mask is not None: + # Compute mean loss over valid masked elements + mask = self.mask.to(dtype=squared_diff.dtype) + return (squared_diff * mask).sum() / mask.expand_as(squared_diff).sum() + + return torch.mean(squared_diff) diff --git a/src/autocast/scripts/setup.py b/src/autocast/scripts/setup.py index 0641b09c..4fce80d3 100644 --- a/src/autocast/scripts/setup.py +++ b/src/autocast/scripts/setup.py @@ -328,7 +328,25 @@ def _build_processor( spatial_resolution=proc_kwargs.get("spatial_resolution"), ) target = processor_config.get("_target_") if processor_config else None + + # Load mask if this is MaskedFlowMatchingProcessor + mask = None + if "masked_flow_matching" in (target or ""): + mask_path = processor_config.get("mask_path") + if mask_path: + log.info("Loading mask from %s", mask_path) + mask = torch.load(mask_path) + # Remove mask_path from config (it's only for setup, not for processor init) + # Temporarily disable struct mode to allow removal + from omegaconf import OmegaConf + struct_mode = OmegaConf.is_struct(processor_config) + OmegaConf.set_struct(processor_config, False) + processor_config.pop("mask_path", None) + OmegaConf.set_struct(processor_config, struct_mode) + filtered_kwargs = _filter_kwargs_for_target(target, proc_kwargs) + if mask is not 
None: + filtered_kwargs["mask"] = mask return instantiate(processor_config, **filtered_kwargs) diff --git a/src/autocast/utils/plots.py b/src/autocast/utils/plots.py index 16d5b864..9d644df2 100644 --- a/src/autocast/utils/plots.py +++ b/src/autocast/utils/plots.py @@ -146,8 +146,8 @@ def _range_from_arrays(arrays): rows_to_plot.append((pred_uq_batch, pred_uq_label, "inferno")) total_rows = len(rows_to_plot) - fig = plt.figure(figsize=(C * 4, total_rows * 4)) - gs = GridSpec(total_rows, C, figure=fig, hspace=0.3, wspace=0.3) + fig = plt.figure(figsize=(C * 6, total_rows * 4)) + gs = GridSpec(total_rows, C, figure=fig, hspace=0.3, wspace=0.4) axes = [] images = [] diff --git a/test_baseline.sh b/test_baseline.sh new file mode 100644 index 00000000..9478c6e5 --- /dev/null +++ b/test_baseline.sh @@ -0,0 +1,30 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 2:00:00 +#SBATCH --nodes 1 +#SBATCH --gpus 1 +#SBATCH --tasks-per-node 4 +#SBATCH --job-name test_baseline_epd + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast +export PYTHONPATH=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/src:$PYTHONPATH + +echo "Testing baseline (regular processor, no mask)..." 
+ +python -m autocast.scripts.train.encoder_processor_decoder \ + --config-path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/configs \ + datamodule=osisaf_nh_sic \ + trainer.max_epochs=1 \ + trainer.accelerator=gpu \ + trainer.devices=1 diff --git a/test_config.sh b/test_config.sh new file mode 100644 index 00000000..673060e2 --- /dev/null +++ b/test_config.sh @@ -0,0 +1,21 @@ +#!/bin/bash +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast + +python -m autocast.scripts.train.encoder_processor_decoder \ + --config-path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/configs \ + datamodule=osisaf_nh_sic \ + model.processor=masked_flow_matching \ + +model.processor.mask_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt \ + trainer.max_epochs=1 \ + trainer.accelerator=gpu \ + trainer.devices=1 \ + hydra.verbose=true \ + --cfg job 2>&1 | grep -A 5 "processor\|backbone" | head -20 diff --git a/train_masked.sh b/train_masked.sh new file mode 100644 index 00000000..a2f39842 --- /dev/null +++ b/train_masked.sh @@ -0,0 +1,30 @@ +#!/bin/bash +#SBATCH --account=vjgo8416-ai-phy-sys +#SBATCH --qos turing +#SBATCH --time 2:00:00 +#SBATCH --nodes 1 +#SBATCH --gpus 1 +#SBATCH --tasks-per-node 4 +#SBATCH --job-name train_masked_epd + +set -e + +module purge +module load baskerville +module load bask-apps/live +module load Python/3.10.8-GCCcore-12.2.0 + +source /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/miniconda3/etc/profile.d/conda.sh +conda activate autocast + +cd /bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast +export PYTHONPATH=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/src:$PYTHONPATH + +python -m 
autocast.scripts.train.encoder_processor_decoder \ + --config-path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/code/autocast/configs \ + datamodule=osisaf_nh_sic \ + model.processor=masked_flow_matching \ + +model.processor.mask_path=/bask/projects/v/vjgo8416-ai-phy-sys/qqaa9560/data/seaice/processed_autocast/osisaf_nh_sic_all/land_mask.pt \ + trainer.max_epochs=2 \ + trainer.accelerator=gpu \ + trainer.devices=1 diff --git a/trainer.max_epochs=2 b/trainer.max_epochs=2 new file mode 100644 index 00000000..e69de29b diff --git a/url b/url new file mode 100644 index 00000000..e69de29b