Skip to content

Commit ffe6dd8

Browse files
committed
feat: enabled forward passes with obs models to run without obs data; added nn_fill on output data outside of allowed range, instead of clipping.
1 parent a809f45 commit ffe6dd8

5 files changed

Lines changed: 44 additions & 15 deletions

File tree

sup3r/models/abstract.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,15 @@
3636

3737
logger = logging.getLogger(__name__)
3838

39-
40-
SUP3R_LAYERS = (
41-
Sup3rAdder,
42-
Sup3rConcat,
39+
SUP3R_OBS_LAYERS = (
4340
Sup3rConcatObs,
4441
Sup3rConcatEmbeddedObs,
4542
Sup3rConcatWeightedObs,
4643
Sup3rConcatWeightedObsWithEmbedding,
4744
)
4845

46+
SUP3R_LAYERS = (Sup3rAdder, Sup3rConcat, *SUP3R_OBS_LAYERS)
47+
4948

5049
# pylint: disable=E1101,W0201,E0203
5150
class AbstractSingleModel(ABC, TensorboardMixIn):
@@ -1041,7 +1040,15 @@ def generate(
10411040
try:
10421041
for i, layer in enumerate(self.generator.layers[1:]):
10431042
layer_num = i + 1
1044-
if isinstance(layer, SUP3R_LAYERS):
1043+
is_obs_layer = isinstance(layer, SUP3R_OBS_LAYERS)
1044+
is_exo_layer = isinstance(layer, SUP3R_LAYERS)
1045+
if is_obs_layer and layer.name not in exogenous_data:
1046+
msg = (f'Observation data not given for {layer.name}. '
1047+
'Will run forward pass without it.')
1048+
logger.warning(msg)
1049+
warn(msg)
1050+
hi_res = layer(hi_res)
1051+
elif is_exo_layer:
10451052
msg = (
10461053
f'layer.name = {layer.name} does not match any '
10471054
'features in exogenous_data '

sup3r/postprocessing/collectors/nc.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import logging
77
import os
88

9+
import xarray as xr
910
from rex.utilities.loggers import init_logger
1011

1112
from sup3r.preprocessing.cachers import Cacher
13+
from sup3r.preprocessing.names import Dimension
1214
from sup3r.utilities.utilities import xr_open_mfdataset
1315

1416
from .base import BaseCollector
@@ -32,6 +34,12 @@ def collect(
3234
):
3335
"""Collect data files from a dir to one output file.
3436
37+
TODO: This assumes that if there is any spatial chunking it is split
38+
by latitude. This should be generalized to allow for any spatial
39+
chunking and any dimension. This will either require a new file
40+
naming scheme with a spatial index for both latitude and
41+
longitude or checking each chunk to see how they are split.
42+
3543
Filename requirements:
3644
- Should end with ".nc"
3745
@@ -76,10 +84,19 @@ def collect(
7684
logger.info(f'overwrite=True, removing {out_file}.')
7785
os.remove(out_file)
7886

87+
spatial_chunks = collector.group_spatial_chunks()
88+
7989
tmp_file = out_file + '.tmp'
8090
if not os.path.exists(tmp_file):
81-
res_kwargs = res_kwargs or {}
82-
out = xr_open_mfdataset(collector.flist, **res_kwargs)
91+
res_kwargs = res_kwargs or {
92+
'combine': 'nested',
93+
'concat_dim': Dimension.TIME,
94+
}
95+
for s_idx in spatial_chunks:
96+
spatial_chunks[s_idx] = xr_open_mfdataset(
97+
spatial_chunks[s_idx], **res_kwargs
98+
)
99+
out = xr.concat(spatial_chunks.values(), dim=Dimension.SOUTH_NORTH)
83100
Cacher.write_netcdf(tmp_file, data=out, features=features)
84101

85102
os.replace(tmp_file, out_file)
@@ -88,12 +105,12 @@ def collect(
88105
logger.info('Finished file collection.')
89106

90107
def group_spatial_chunks(self):
91-
"""Group same spatial chunks together so each chunk has same spatial
108+
"""Group same spatial chunks together so each entry has same spatial
92109
footprint but different times"""
93110
chunks = {}
94111
for file in self.flist:
95-
s_chunk = file.split('_')[0]
96-
dirname = os.path.dirname(file)
97-
s_file = os.path.join(dirname, f's_{s_chunk}.nc')
98-
chunks[s_file] = [*chunks.get(s_file, []), s_file]
112+
_, s_idx = self.get_chunk_indices(file)
113+
chunks[s_idx] = [*chunks.get(s_idx, []), file]
114+
for k, v in chunks.items():
115+
chunks[k] = sorted(v)
99116
return chunks

sup3r/preprocessing/accessor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,11 @@ def interpolate_na(self, **kwargs):
372372
'use_coordinate', False
373373
)
374374
self._ds[feat] = self._ds[feat].interpolate_na(**kwargs)
375-
else:
375+
elif np.isnan(self._ds[feat]).any():
376376
msg = (
377377
'No dim given for interpolate_na. This will use nearest '
378-
'neighbor fill, which could take some time.'
378+
f'neighbor fill to interpolate {feat}, which could take '
379+
'some time.'
379380
)
380381
logger.warning(msg)
381382
warn(msg)

sup3r/preprocessing/data_handlers/exo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, steps):
9797
warn(msg)
9898
steps_list = entry.get('steps', [entry])
9999
for i, step in enumerate(steps_list):
100-
msg = (f'ExoData entry for {feat}, step #{i+1} has no '
100+
msg = (f'ExoData entry for {feat}, step #{i + 1}, has no '
101101
'"combine_type" key. Assuming this is for a '
102102
'layer combination.')
103103
if 'combine_type' not in step:

sup3r/utilities/utilities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def enforce_limits(features, data, nn_fill=False):
156156
max_val = OUTPUT_ATTRS[dset_name].get('max', np.inf)
157157
min_val = OUTPUT_ATTRS[dset_name].get('min', -np.inf)
158158
enforcing_msg = f'Enforcing range of ({min_val}, {max_val}) for "{fn}"'
159+
if nn_fill:
160+
enforcing_msg += ' with nearest neighbor interpolation.'
161+
else:
162+
enforcing_msg += ' with clipping.'
159163

160164
f_max = data[..., fidx].max()
161165
f_min = data[..., fidx].min()

0 commit comments

Comments
 (0)