
Commit 0d7f65c

Merge pull request #280 from NCAR/loader_update
Update era5.py
2 parents: dac3dbe + 16a8d96 · commit 0d7f65c

4 files changed: +110 −89 lines

config/era5_new_data_config.yaml

Lines changed: 18 additions & 16 deletions
@@ -1,21 +1,23 @@
 data:
   source:
     ERA5:
-      prognostic:
-        # upper-air variables
-        vars_3D: ['U','V','T','Q']
-        vars_2D: ['SP','t2m', 'V500','U500','T500','Z500','Q500']
-        path: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/y_TOTAL*'
-
-      diagnostic: null
-
-      dynamic_forcing:
-        vars_2D: ['tsi']
-        path: '/glade/derecho/scratch/dgagne/credit_solar_nc_1h_0.25deg/*.nc'
-
-      static:
-        vars_2D: ['Z_GDS4_SFC','LSM']
-        path: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
+      level_coord: "level"
+      levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
+      variables:
+        prognostic:
+          vars_3D: ['U','V','T','Q']
+          vars_2D: ['SP','t2m', 'V500','U500','T500','Z500','Q500']
+          path: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/y_TOTAL*'
+
+        diagnostic: null
+
+        dynamic_forcing:
+          vars_2D: ['tsi']
+          path: '/glade/derecho/scratch/dgagne/credit_solar_nc_1h_0.25deg/*.nc'
+
+        static:
+          vars_2D: ['Z_GDS4_SFC','LSM']
+          path: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

   start_datetime: "2021-12-31"
   end_datetime: "2022-01-05"

@@ -186,4 +188,4 @@ pbs: #derecho
   ncpus: 64
   ngpus: 4
   mem: '480GB'
-  queue: 'main'
+  queue: 'main'
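
For reference, a minimal sketch (not part of the commit) of how the restructured ERA5 block is read back: the YAML string below mirrors the new layout above with a trimmed level list, and the PyYAML usage and variable names are illustrative only.

# Illustrative only: a trimmed-down config mirroring the new ERA5 layout.
import yaml  # PyYAML

cfg_text = """
data:
  source:
    ERA5:
      level_coord: "level"
      levels: [10, 30, 137]
      variables:
        prognostic:
          vars_3D: ['U', 'V', 'T', 'Q']
          vars_2D: ['SP', 't2m']
          path: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/y_TOTAL*'
        diagnostic: null
"""

cfg = yaml.safe_load(cfg_text)
era5 = cfg["data"]["source"]["ERA5"]

# The level coordinate name and level list now sit directly under ERA5,
# while the per-field variable specs moved one layer down under "variables".
print(era5["level_coord"], era5["levels"])
print(era5["variables"]["prognostic"]["vars_3D"])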

credit/datasets/era5.py

Lines changed: 52 additions & 41 deletions
@@ -26,22 +26,20 @@ class ERA5Dataset(Dataset):
     data:
       source:
         ERA5:
-          prognostic:
-            vars_3D: ['T', 'U', 'V', 'Q']
-            vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
-            path: "<path to prognostic>"
-          diagnostic:
-            vars_3D: ['T', 'U', 'V', 'Q']
-            vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
-            path: "<path to diagnostic>"
-          static:
-            vars_3D: ['T', 'U', 'V', 'Q']
-            vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
-            path: "<path to static>"
-          dynamic_forcing:
-            vars_3D: ['T', 'U', 'V', 'Q']
-            vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
-            path: "<path to dynamic forcing>"
+          level_coord: "level"
+          levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
+          variables:
+            prognostic:
+              vars_3D: ['T', 'U', 'V', 'Q']
+              vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
+              path: "<path to prognostic>"
+            diagnostic: null
+            static:
+              vars_2D: ['Z_GDS4_SFC','LSM']
+              path: "<path to static>"
+            dynamic_forcing:
+              vars_2D: ['tsi']
+              path: "<path to dynamic forcing>"

     start_datetime: "2017-01-01"
     end_datetime: "2019-12-31"
@@ -50,13 +48,15 @@ class ERA5Dataset(Dataset):
     Assumptions:
     1) The data MUST be stored in yearly zarr or netCDF files with a unique 4-digit year (YYYY) in the file name
     2) "time" dimension / coordinate is present
-    3) "level" dimension name representing the vertical level
-    4) Dimension order of ('time', level', 'latitude', 'longitude') for 3D vars (remove level for 2D)
+    3) "level" or "pressure" coordinate name representing the vertical level
+    4) Dimension order of ('time', level/pressure', 'latitude', 'longitude') for 3D vars (remove level/pressure for 2D)
     5) Data should be chunked efficiently for a fast read (recommend small chunks across time dimension).
     """

     def __init__(self, config, return_target=False):
         self.source_name = "ERA5"
+        self.level_coord = config["source"]["ERA5"]["level_coord"]
+        self.levels = config["source"]["ERA5"]["levels"]
         self.return_target = return_target
         self.dt = pd.Timedelta(config["timestep"])
         self.num_forecast_steps = config["forecast_len"] + 1
@@ -66,8 +66,9 @@ def __init__(self, config, return_target=False):
         self.years = [str(y) for y in self.datetimes.year]
         self.file_dict = {}
         self.var_dict = {}
+        self.variable_meta = self._build_var_metadata(config)

-        for field_type, d in config["source"][self.source_name].items():
+        for field_type, d in config["source"][self.source_name]["variables"].items():
             if field_type not in VALID_FIELD_TYPES:
                 raise KeyError(
                     f"Unknown field_type '{field_type}' in config['source']['{self.source_name}']. "
@@ -140,8 +141,9 @@ def __getitem__(self, args):
             if key in self.file_dict.keys():
                 self._open_ds_extract_fields(key, t_target, return_data, is_target=True)
             self._pop_and_merge_targets(return_data)
+            return_data["metadata"]["target_datetime"] = int(t_target.value)

-        self._add_metadata(return_data, t, t_target)
+        return_data["metadata"]["input_datetime"] = int(t.value)

         return return_data

@@ -164,23 +166,21 @@ def _open_ds_extract_fields(self, field_type, t, return_data, is_target=False):
             ds = dataset.sel(time=t)
         else:
             ds = dataset
-
         ds_all_vars = ds[self.var_dict[field_type]["vars_3D"] + self.var_dict[field_type]["vars_2D"]]

         ds_3D = ds_all_vars[self.var_dict[field_type]["vars_3D"]]
         ds_2D = ds_all_vars[self.var_dict[field_type]["vars_2D"]]
-        data_np, meta = self._reshape_and_concat(ds_3D, ds_2D)
+        data_np = self._reshape_and_concat(ds_3D, ds_2D)

         if is_target:
             if field_type == "prognostic":
                 return_data["target_prognostic"] = torch.tensor(data_np).float()
             elif field_type == "diagnostic":
                 return_data["target_diagnostic"] = torch.tensor(data_np).float()
+
         else:
             return_data[field_type] = torch.tensor(data_np).float()

-        return_data["metadata"][f"{field_type}_var_order"] = meta
-
     def _reshape_and_concat(self, ds_3D, ds_2D):
         """
         Stack 3D variables along level and variable, concatenate with 2D variables, and reorder dimensions.
@@ -190,38 +190,49 @@ def _reshape_and_concat(self, ds_3D, ds_2D):
             ds_2D (xr.Dataset): Xarray dataset with 2D spatial variables
         """
         data_list = []
-        meta_3D, meta_2D = [], []

         if ds_3D:
-            data_3D = ds_3D.to_array().stack({"level_var": ["variable", "level"]})
-            meta_3D = data_3D.level_var.values.tolist()
+            data_3D = (
+                ds_3D.sel({self.level_coord: self.levels})
+                .to_array()
+                .stack({"level_var": ["variable", self.level_coord]})
+            )
             data_3D = np.expand_dims(data_3D.values.transpose(2, 0, 1), axis=1)
             data_list.append(data_3D)

         if ds_2D:
             data_2D = ds_2D.to_array()
-            meta_2D = data_2D["variable"].values.tolist()
             data_2D = np.expand_dims(data_2D, axis=1)
             data_list.append(data_2D)

         combined_data = np.concatenate(data_list, axis=0)
-        meta = meta_3D + meta_2D

-        return combined_data, meta
+        return combined_data

-    def _add_metadata(self, return_data, t, t_target=None):
-        """
-        Update metadata dictionary
+    def _build_var_metadata(self, config):
+        """Build variable order metadata"""

-        Args:
-            return_data (dict): Return dictionary
-            t (int): Time step
-            t_target: Target time step or None
-        """
-        return_data["metadata"]["input_datetime"] = int(t.value)
+        var_meta = {}
+        source_cfg = config["source"][self.source_name]
+        levels = source_cfg.get("levels", [])
+        variables = source_cfg.get("variables", {}) or {}

-        if self.return_target:
-            return_data["metadata"]["target_datetime"] = int(t_target.value)
+        for field_type, spec in variables.items():
+            if spec is None:
+                continue
+
+            var_meta[field_type] = []
+
+            # Expand 3D variables over levels
+            for v in spec.get("vars_3D") or []:
+                for lev in levels:
+                    var_meta[field_type].append(f"{self.source_name}_{v}_{lev}")
+
+            # Add 2D variables directly
+            for v in spec.get("vars_2D") or []:
+                var_meta[field_type].append(f"{self.source_name}_{v}")
+
+        return var_meta

     def _convert_cf_time(self, ts):
         """

tests/era5_dataset_test.py

Lines changed: 20 additions & 16 deletions
@@ -80,22 +80,26 @@ def minimal_config():
         "end_datetime": "2023-01-05",
         "source": {
             "ERA5": {
-                "prognostic": {
-                    "vars_3D": ["T", "U"],
-                    "vars_2D": ["SP"],
-                    "path": "/fake/*.zarr",
-                },
-                "dynamic_forcing": {
-                    "vars_2D": ["tsi"],
-                    "path": "/fake/*.zarr",
-                },
-                "static": {
-                    "vars_2D": ["LSM"],
-                    "path": "/fake/*.zarr",
-                },
-                "diagnostic": {
-                    "vars_2D": ["TP"],
-                    "path": "/fake/*.zarr",
+                "level_coord": "level",
+                "levels": [1000, 850, 500, 300],
+                "variables": {
+                    "prognostic": {
+                        "vars_3D": ["T", "U"],
+                        "vars_2D": ["SP"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "dynamic_forcing": {
+                        "vars_2D": ["tsi"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "static": {
+                        "vars_2D": ["LSM"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "diagnostic": {
+                        "vars_2D": ["TP"],
+                        "path": "/fake/*.zarr",
+                    },
                 },
             }
         },
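
As a rough cross-check (not an actual test in this commit), the number of channels each field type contributes can be derived directly from this fixture: every 3D variable expands over all configured levels and every 2D variable adds one channel, matching the concatenation in _reshape_and_concat. The helper below is hypothetical.

# Hypothetical helper: expected channel count per field type
# for the minimal_config fixture above.
levels = [1000, 850, 500, 300]
variables = {
    "prognostic": {"vars_3D": ["T", "U"], "vars_2D": ["SP"]},
    "dynamic_forcing": {"vars_2D": ["tsi"]},
    "static": {"vars_2D": ["LSM"]},
    "diagnostic": {"vars_2D": ["TP"]},
}

def expected_channels(spec, n_levels):
    # 3D variables multiply by the number of levels; 2D variables count once.
    return len(spec.get("vars_3D", [])) * n_levels + len(spec.get("vars_2D", []))

for field_type, spec in variables.items():
    print(field_type, expected_channels(spec, len(levels)))
# prognostic 9, dynamic_forcing 1, static 1, diagnostic 1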

tests/sampler_test.py

Lines changed: 20 additions & 16 deletions
@@ -81,22 +81,26 @@ def minimal_config():
         "end_datetime": "2023-01-05",
         "source": {
             "ERA5": {
-                "prognostic": {
-                    "vars_3D": ["T", "U"],
-                    "vars_2D": ["SP"],
-                    "path": "/fake/*.zarr",
-                },
-                "dynamic_forcing": {
-                    "vars_2D": ["tsi"],
-                    "path": "/fake/*.zarr",
-                },
-                "static": {
-                    "vars_2D": ["LSM"],
-                    "path": "/fake/*.zarr",
-                },
-                "diagnostic": {
-                    "vars_2D": ["TP"],
-                    "path": "/fake/*.zarr",
+                "level_coord": "level",
+                "levels": [1000, 850, 500, 300],
+                "variables": {
+                    "prognostic": {
+                        "vars_3D": ["T", "U"],
+                        "vars_2D": ["SP"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "dynamic_forcing": {
+                        "vars_2D": ["tsi"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "static": {
+                        "vars_2D": ["LSM"],
+                        "path": "/fake/*.zarr",
+                    },
+                    "diagnostic": {
+                        "vars_2D": ["TP"],
+                        "path": "/fake/*.zarr",
+                    },
                 },
             }
         },
