@@ -26,22 +26,20 @@ class ERA5Dataset(Dataset):
2626 data:
2727 source:
2828 ERA5:
29- prognostic:
30- vars_3D: ['T', 'U', 'V', 'Q']
31- vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
32- path: "<path to prognostic>"
33- diagnostic:
34- vars_3D: ['T', 'U', 'V', 'Q']
35- vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
36- path: "<path to diagnostic>"
37- static:
38- vars_3D: ['T', 'U', 'V', 'Q']
39- vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
40- path: "<path to static>"
41- dynamic_forcing:
42- vars_3D: ['T', 'U', 'V', 'Q']
43- vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
44- path: "<path to dynamic forcing>"
29+ level_coord: "level"
30+ levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
31+ variables:
32+ prognostic:
33+ vars_3D: ['T', 'U', 'V', 'Q']
34+ vars_2D: ['T500', 'U500', 'V500', 'Q500' ,'Z500', 'tsi', 't2m','SP']
35+ path: "<path to prognostic>"
36+ diagnostic: null
37+ static:
38+ vars_2D: ['Z_GDS4_SFC','LSM']
39+ path: "<path to static>"
40+ dynamic_forcing:
41+ vars_2D: ['tsi']
42+ path: "<path to dynamic forcing>"
4543
4644 start_datetime: "2017-01-01"
4745 end_datetime: "2019-12-31"
@@ -50,13 +48,15 @@ class ERA5Dataset(Dataset):
5048 Assumptions:
5149 1) The data MUST be stored in yearly zarr or netCDF files with a unique 4-digit year (YYYY) in the file name
5250 2) "time" dimension / coordinate is present
53- 3) "level" dimension name representing the vertical level
54- 4) Dimension order of ('time', level', 'latitude', 'longitude') for 3D vars (remove level for 2D)
51+ 3) A vertical coordinate (e.g. "level" or "pressure"), whose name is set by `level_coord` in the config, representing the vertical level
52+ 4) Dimension order of ('time', 'level/pressure', 'latitude', 'longitude') for 3D vars (remove level/pressure for 2D)
5553 5) Data should be chunked efficiently for a fast read (recommend small chunks across time dimension).
5654 """
5755
5856 def __init__ (self , config , return_target = False ):
5957 self .source_name = "ERA5"
58+ self .level_coord = config ["source" ]["ERA5" ]["level_coord" ]
59+ self .levels = config ["source" ]["ERA5" ]["levels" ]
6060 self .return_target = return_target
6161 self .dt = pd .Timedelta (config ["timestep" ])
6262 self .num_forecast_steps = config ["forecast_len" ] + 1
@@ -66,8 +66,9 @@ def __init__(self, config, return_target=False):
6666 self .years = [str (y ) for y in self .datetimes .year ]
6767 self .file_dict = {}
6868 self .var_dict = {}
69+ self .variable_meta = self ._build_var_metadata (config )
6970
70- for field_type , d in config ["source" ][self .source_name ].items ():
71+ for field_type , d in config ["source" ][self .source_name ][ "variables" ] .items ():
7172 if field_type not in VALID_FIELD_TYPES :
7273 raise KeyError (
7374 f"Unknown field_type '{ field_type } ' in config['source']['{ self .source_name } ']. "
@@ -140,8 +141,9 @@ def __getitem__(self, args):
140141 if key in self .file_dict .keys ():
141142 self ._open_ds_extract_fields (key , t_target , return_data , is_target = True )
142143 self ._pop_and_merge_targets (return_data )
144+ return_data ["metadata" ]["target_datetime" ] = int (t_target .value )
143145
144- self . _add_metadata ( return_data , t , t_target )
146+ return_data [ "metadata" ][ "input_datetime" ] = int ( t . value )
145147
146148 return return_data
147149
@@ -164,23 +166,21 @@ def _open_ds_extract_fields(self, field_type, t, return_data, is_target=False):
164166 ds = dataset .sel (time = t )
165167 else :
166168 ds = dataset
167-
168169 ds_all_vars = ds [self .var_dict [field_type ]["vars_3D" ] + self .var_dict [field_type ]["vars_2D" ]]
169170
170171 ds_3D = ds_all_vars [self .var_dict [field_type ]["vars_3D" ]]
171172 ds_2D = ds_all_vars [self .var_dict [field_type ]["vars_2D" ]]
172- data_np , meta = self ._reshape_and_concat (ds_3D , ds_2D )
173+ data_np = self ._reshape_and_concat (ds_3D , ds_2D )
173174
174175 if is_target :
175176 if field_type == "prognostic" :
176177 return_data ["target_prognostic" ] = torch .tensor (data_np ).float ()
177178 elif field_type == "diagnostic" :
178179 return_data ["target_diagnostic" ] = torch .tensor (data_np ).float ()
180+
179181 else :
180182 return_data [field_type ] = torch .tensor (data_np ).float ()
181183
182- return_data ["metadata" ][f"{ field_type } _var_order" ] = meta
183-
184184 def _reshape_and_concat (self , ds_3D , ds_2D ):
185185 """
186186 Stack 3D variables along level and variable, concatenate with 2D variables, and reorder dimensions.
@@ -190,38 +190,49 @@ def _reshape_and_concat(self, ds_3D, ds_2D):
190190 ds_2D (xr.Dataset): Xarray dataset with 2D spatial variables
191191 """
192192 data_list = []
193- meta_3D , meta_2D = [], []
194193
195194 if ds_3D :
196- data_3D = ds_3D .to_array ().stack ({"level_var" : ["variable" , "level" ]})
197- meta_3D = data_3D .level_var .values .tolist ()
195+ data_3D = (
196+ ds_3D .sel ({self .level_coord : self .levels })
197+ .to_array ()
198+ .stack ({"level_var" : ["variable" , self .level_coord ]})
199+ )
198200 data_3D = np .expand_dims (data_3D .values .transpose (2 , 0 , 1 ), axis = 1 )
199201 data_list .append (data_3D )
200202
201203 if ds_2D :
202204 data_2D = ds_2D .to_array ()
203- meta_2D = data_2D ["variable" ].values .tolist ()
204205 data_2D = np .expand_dims (data_2D , axis = 1 )
205206 data_list .append (data_2D )
206207
207208 combined_data = np .concatenate (data_list , axis = 0 )
208- meta = meta_3D + meta_2D
209209
210- return combined_data , meta
210+ return combined_data
211211
212- def _add_metadata (self , return_data , t , t_target = None ):
213- """
214- Update metadata dictionary
212+ def _build_var_metadata (self , config ):
213+ """Build variable order metadata"""
215214
216- Args:
217- return_data (dict): Return dictionary
218- t (int): Time step
219- t_target: Target time step or None
220- """
221- return_data ["metadata" ]["input_datetime" ] = int (t .value )
215+ var_meta = {}
216+ source_cfg = config ["source" ][self .source_name ]
217+ levels = source_cfg .get ("levels" , [])
218+ variables = source_cfg .get ("variables" , {}) or {}
222219
223- if self .return_target :
224- return_data ["metadata" ]["target_datetime" ] = int (t_target .value )
220+ for field_type , spec in variables .items ():
221+ if spec is None :
222+ continue
223+
224+ var_meta [field_type ] = []
225+
226+ # Expand 3D variables over levels
227+ for v in spec .get ("vars_3D" ) or []:
228+ for lev in levels :
229+ var_meta [field_type ].append (f"{ self .source_name } _{ v } _{ lev } " )
230+
231+ # Add 2D variables directly
232+ for v in spec .get ("vars_2D" ) or []:
233+ var_meta [field_type ].append (f"{ self .source_name } _{ v } " )
234+
235+ return var_meta
225236
226237 def _convert_cf_time (self , ts ):
227238 """
0 commit comments