Skip to content

Commit 2046f8b

Browse files
authored
Merge pull request #311 from NatLabRockies/bnb/collection
nc collection for arbitrary chunk shapes
2 parents bfcbef2 + 97ed8fd commit 2046f8b

File tree

9 files changed

+116
-76
lines changed

9 files changed

+116
-76
lines changed

sup3r/pipeline/forward_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ def run_chunk(
670670
data=output_data,
671671
features=lowered(model.hr_out_features),
672672
lat_lon=chunk.hr_lat_lon,
673+
row_inds=chunk.row_inds,
674+
col_inds=chunk.col_inds,
673675
times=chunk.hr_times,
674676
out_file=chunk.out_file,
675677
meta_data=meta,

sup3r/pipeline/strategy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class ForwardPassChunk:
4646
hr_lat_lon: Union[np.ndarray, da.core.Array]
4747
hr_times: pd.DatetimeIndex
4848
gids: Union[np.ndarray, da.core.Array]
49+
row_inds: np.ndarray
50+
col_inds: np.ndarray
4951
out_file: str
5052
pad_width: tuple[tuple, tuple, tuple]
5153
index: int
@@ -458,6 +460,16 @@ def hr_lat_lon(self):
458460
)
459461
return OutputHandler.get_lat_lon(lr_lat_lon, shape)
460462

463+
@cached_property
464+
def grid_inds(self):
465+
"""Get row and column indices for the full high resolution grid. This
466+
is used to collect spatially contiguous data for stitching output
467+
chunks back together."""
468+
shape = self.hr_lat_lon.shape[:-1]
469+
row_inds = np.arange(shape[0])
470+
col_inds = np.arange(shape[1])
471+
return row_inds, col_inds
472+
461473
@cached_property
462474
def out_files(self):
463475
"""Get list of output file names for each file chunk forward pass."""
@@ -580,6 +592,8 @@ def init_chunk(self, chunk_index=0):
580592
hr_times=OutputHandler.get_times(
581593
lr_times, self.t_enhance * len(lr_times)
582594
),
595+
row_inds=self.grid_inds[0][hr_slice[0]],
596+
col_inds=self.grid_inds[1][hr_slice[1]],
583597
gids=self.gids[hr_slice[:2]],
584598
out_file=self.out_files[chunk_index],
585599
pad_width=self.fwp_slicer.extra_padding[chunk_index],

sup3r/postprocessing/collectors/h5.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from rex.utilities.loggers import init_logger
1111
from scipy.spatial import KDTree
1212

13+
from sup3r.preprocessing import Loader
1314
from sup3r.preprocessing.utilities import _mem_check
1415
from sup3r.utilities.utilities import get_dset_attrs, get_tmp_file
1516
from sup3r.writers import RexOutputs
@@ -93,13 +94,15 @@ def get_coordinate_indices(cls, target_meta, full_meta, threshold=1e-4):
9394
threshold : float
9495
Threshold distance for finding target coordinates within full meta
9596
"""
96-
ll2 = np.vstack(
97-
(full_meta.latitude.values, full_meta.longitude.values)
98-
).T
97+
ll2 = np.vstack((
98+
full_meta.latitude.values,
99+
full_meta.longitude.values,
100+
)).T
99101
tree = KDTree(ll2)
100-
targets = np.vstack(
101-
(target_meta.latitude.values, target_meta.longitude.values)
102-
).T
102+
targets = np.vstack((
103+
target_meta.latitude.values,
104+
target_meta.longitude.values,
105+
)).T
103106
_, indices = tree.query(targets, distance_upper_bound=threshold)
104107
indices = indices[indices < len(full_meta)]
105108
return indices
@@ -723,7 +726,7 @@ def collect(
723726
cls,
724727
file_paths,
725728
out_file,
726-
features,
729+
features='all',
727730
max_workers=None,
728731
log_level=None,
729732
log_file=None,
@@ -746,23 +749,16 @@ def collect(
746749
``*_{temporal_chunk_index}_{spatial_chunk_index}.h5``.
747750
out_file : str
748751
File path of final output file.
749-
features : list
750-
List of dsets to collect
752+
features : list | str
753+
List of dsets to collect. If 'all' then all datasets will be
754+
collected
751755
max_workers : int | None
752756
Number of workers to use in parallel. 1 runs serial,
753757
None will use all available workers.
754758
log_level : str | None
755759
Desired log level, None will not initialize logging.
756760
log_file : str | None
757761
Target log file. None logs to stdout.
758-
write_status : bool
759-
Flag to write status file once complete if running from pipeline.
760-
job_name : str
761-
Job name for status file if running from pipeline.
762-
pipeline_step : str, optional
763-
Name of the pipeline step being run. If ``None``, the
764-
``pipeline_step`` will be set to ``"collect``, mimicking old reV
765-
behavior. By default, ``None``.
766762
target_meta_file : str
767763
Path to target final meta containing coordinates to keep from the
768764
full file list collected meta. This can be but is not necessarily a
@@ -796,6 +792,11 @@ def collect(
796792
os.makedirs(os.path.dirname(out_file), exist_ok=True)
797793

798794
collector = cls(file_paths)
795+
features = (
796+
Loader(collector.flist[0]).features
797+
if features == 'all'
798+
else features
799+
)
799800
logger.info(
800801
'Collecting %s files to %s', len(collector.flist), out_file
801802
)

sup3r/postprocessing/collectors/nc.py

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import xarray as xr
1010
from rex.utilities.loggers import init_logger
1111

12-
from sup3r.preprocessing.loaders import Loader
12+
from sup3r.preprocessing import Loader
1313
from sup3r.preprocessing.names import Dimension
1414
from sup3r.writers import Cacher
1515

@@ -32,18 +32,9 @@ def collect(
3232
overwrite=True,
3333
res_kwargs=None,
3434
cacher_kwargs=None,
35-
is_regular_grid=True,
3635
):
3736
"""Collect data files from a dir to one output file.
3837
39-
TODO: For a regular grid (lat values are constant across lon and vice
40-
versa) collecting lat / lon chunks is supported. For curvilinear grids
41-
only collection of chunks that are split by latitude are supported.
42-
This should be generalized to allow for any spatial chunking and any
43-
dimension. I think this would require a new file naming scheme with a
44-
spatial index for both latitude and longitude or checking each chunk
45-
to see how they are split.
46-
4738
Filename requirements:
4839
- Should end with ".nc"
4940
@@ -61,23 +52,12 @@ def collect(
6152
Desired log level, None will not initialize logging.
6253
log_file : str | None
6354
Target log file. None logs to stdout.
64-
write_status : bool
65-
Flag to write status file once complete if running from pipeline.
66-
job_name : str
67-
Job name for status file if running from pipeline.
6855
overwrite : bool
6956
Whether to overwrite existing output file
7057
res_kwargs : dict | None
7158
Dictionary of kwargs to pass to xarray.open_mfdataset.
7259
cacher_kwargs : dict | None
7360
Dictionary of kwargs to pass to Cacher._write_single.
74-
is_regular_grid : bool
75-
Whether the data is on a regular grid. If True then spatial chunks
76-
can be combined across both latitude and longitude. If False then
77-
spatial chunks must all have the same longitude values to be
78-
combined. If you need completely general chunk collection then
79-
you should write chunks to `h5` files and use
80-
:class:`sup3r.postprocessing.collectors.h5.CollectorH5`.
8161
"""
8262
logger.info(f'Initializing collection for file_paths={file_paths}')
8363

@@ -97,33 +77,18 @@ def collect(
9777
logger.info(f'overwrite=True, removing {out_file}.')
9878
os.remove(out_file)
9979

100-
spatial_chunks = collector.group_spatial_chunks()
101-
10280
if not os.path.exists(out_file):
103-
res_kwargs = res_kwargs or {
104-
'combine': 'nested',
105-
'concat_dim': Dimension.TIME,
106-
}
107-
for s_idx, sfiles in spatial_chunks.items():
108-
schunk = Loader(sfiles, res_kwargs=res_kwargs)
109-
spatial_chunks[s_idx] = schunk
110-
111-
# Set lat / lon as 1D arrays if regular grid and get the
112-
# xr.Dataset _ds
113-
if is_regular_grid:
114-
spatial_chunks = {
115-
s_idx: schunk.set_regular_grid()._ds
116-
for s_idx, schunk in spatial_chunks.items()
117-
}
118-
out = xr.combine_by_coords(
119-
spatial_chunks.values(), combine_attrs='override'
120-
)
121-
122-
else:
123-
out = xr.concat(
124-
[sc._ds for sc in spatial_chunks.values()],
125-
dim=Dimension.SOUTH_NORTH,
126-
)
81+
dsets = list(
82+
collector.group_spatial_chunks(res_kwargs=res_kwargs).values()
83+
)
84+
85+
# Reset coords so that they are data_vars and can be combined
86+
# across chunks. This is needed because coords can be 2d arrays,
87+
# which can't be used to combine chunks. After combination, set
88+
# them back to coords.
89+
dsets = [ds.reset_coords(Dimension.coords_2d()) for ds in dsets]
90+
out = xr.combine_by_coords(dsets, combine_attrs='override')
91+
out = out.set_coords(Dimension.coords_2d())
12792

12893
cacher_kwargs = cacher_kwargs or {}
12994
Cacher._write_single(
@@ -135,13 +100,14 @@ def collect(
135100

136101
logger.info('Finished file collection.')
137102

138-
def group_spatial_chunks(self):
139-
"""Group same spatial chunks together so each entry has same spatial
140-
footprint but different times"""
103+
def group_spatial_chunks(self, res_kwargs=None):
104+
"""Group same spatial chunks together to get list of files with same
105+
spatial footprint but different times. Return `Loader` instances for
106+
each spatial chunk with combined times."""
141107
chunks = {}
142108
for file in self.flist:
143109
_, s_idx = self.get_chunk_indices(file)
144110
chunks[s_idx] = [*chunks.get(s_idx, []), file]
145111
for k, v in chunks.items():
146-
chunks[k] = sorted(v)
112+
chunks[k] = Loader(sorted(v), res_kwargs=res_kwargs)
147113
return chunks

sup3r/utilities/pytest/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ def make_collect_chunks(td, ext='h5'):
347347
out_file,
348348
meta_data=model_meta_data,
349349
max_workers=1,
350+
row_inds=np.arange(shape[0])[s1_hr],
351+
col_inds=np.arange(shape[1])[s2_hr],
350352
gids=gids[s1_hr, s2_hr],
351353
)
352354

sup3r/writers/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,8 @@ def _write_output(
567567
invert_uv=False,
568568
nn_fill=False,
569569
max_workers=None,
570+
row_inds=None,
571+
col_inds=None,
570572
gids=None,
571573
):
572574
"""Write output to file with specified times and lats/lons"""
@@ -584,6 +586,8 @@ def write_output(
584586
nn_fill=False,
585587
max_workers=None,
586588
gids=None,
589+
row_inds=None,
590+
col_inds=None,
587591
):
588592
"""Write forward pass output to file
589593
@@ -615,6 +619,14 @@ def write_output(
615619
gids : list
616620
List of coordinate indices used to label each lat lon pair and to
617621
help with spatial chunk data collection
622+
row_inds : np.ndarray
623+
Array of row indices for the full high resolution grid. This is
624+
used to collect spatially contiguous data for stitching output
625+
chunks back together.
626+
col_inds : np.ndarray
627+
Array of column indices for the full high resolution grid. This is
628+
used to collect spatially contiguous data for stitching output
629+
chunks back together.
618630
"""
619631
lat_lon = cls.get_lat_lon(low_res_lat_lon, data.shape[:2])
620632
times = cls.get_times(low_res_times, data.shape[-2])
@@ -628,5 +640,7 @@ def write_output(
628640
invert_uv=invert_uv,
629641
nn_fill=nn_fill,
630642
max_workers=max_workers,
643+
row_inds=row_inds,
644+
col_inds=col_inds,
631645
gids=gids,
632646
)

sup3r/writers/h5.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def _write_output(
2828
invert_uv=False,
2929
nn_fill=False,
3030
max_workers=None,
31+
row_inds=None,
32+
col_inds=None,
3133
gids=None,
3234
):
3335
"""Write forward pass output to H5 file
@@ -57,6 +59,14 @@ def _write_output(
5759
neighbour or cap to limits
5860
max_workers : int | None
5961
Max workers to use for inverse transform.
62+
row_inds : np.ndarray
63+
Array of row indices for the full high resolution grid. This is
64+
used to collect spatially contiguous data for stitching output
65+
chunks back together.
66+
col_inds : np.ndarray
67+
Array of column indices for the full high resolution grid. This is
68+
used to collect spatially contiguous data for stitching output
69+
chunks back together.
6070
gids : list
6171
List of coordinate indices used to label each lat lon pair and to
6272
help with spatial chunk data collection
@@ -84,8 +94,20 @@ def _write_output(
8494
if gids is not None
8595
else np.arange(np.prod(lat_lon.shape[:-1]))
8696
)
97+
row_inds = (
98+
row_inds if row_inds is not None else np.arange(lat_lon.shape[0])
99+
)
100+
col_inds = (
101+
col_inds if col_inds is not None else np.arange(lat_lon.shape[1])
102+
)
87103
meta = pd.DataFrame({
88104
'gid': gids.flatten(),
105+
'row_ind': np.repeat(
106+
row_inds[:, np.newaxis], len(col_inds), axis=1
107+
).flatten(),
108+
'col_ind': np.repeat(
109+
col_inds[np.newaxis, :], len(row_inds), axis=0
110+
).flatten(),
89111
'latitude': lat_lon[..., 0].flatten(),
90112
'longitude': lat_lon[..., 1].flatten(),
91113
})

sup3r/writers/nc.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def _write_output(
3030
max_workers=None,
3131
invert_uv=False,
3232
nn_fill=False,
33+
row_inds=None,
34+
col_inds=None,
3335
gids=None,
3436
):
3537
"""Write forward pass output to NETCDF file
@@ -59,6 +61,14 @@ def _write_output(
5961
nn_fill : bool
6062
Whether to fill data outside of limits with nearest neighbour or
6163
cap to limits
64+
row_inds : np.ndarray
65+
Array of row indices for the full high resolution grid. This is
66+
used to help with spatial chunk data collection and should be
67+
included if the output data is spatially chunked.
68+
col_inds : np.ndarray
69+
Array of column indices for the full high resolution grid. This is
70+
used to help with spatial chunk data collection and should be
71+
included if the output data is spatially chunked.
6272
gids : list
6373
List of coordinate indices used to label each lat lon pair and to
6474
help with spatial chunk data collection
@@ -77,9 +87,12 @@ def _write_output(
7787
Dimension.LATITUDE: (Dimension.dims_2d(), lat_lon[:, :, 0]),
7888
Dimension.LONGITUDE: (Dimension.dims_2d(), lat_lon[:, :, 1]),
7989
}
80-
data_vars = {}
8190
if gids is not None:
82-
data_vars = {'gids': (Dimension.dims_2d(), gids)}
91+
coords['gids'] = (Dimension.dims_2d(), gids)
92+
if row_inds is not None and col_inds is not None:
93+
for dim, inds in zip(Dimension.dims_2d(), [row_inds, col_inds]):
94+
coords[dim] = (dim, inds)
95+
data_vars = {}
8396
for i, f in enumerate(features):
8497
data_vars[f] = (
8598
(Dimension.TIME, *Dimension.dims_2d()),
@@ -95,6 +108,6 @@ def _write_output(
95108
Cacher._write_single(
96109
out_file=out_file,
97110
data=ds,
98-
features=features,
111+
features=list(data_vars.keys()),
99112
max_workers=max_workers,
100113
)

0 commit comments

Comments
 (0)