Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,8 @@ def run_chunk(
data=output_data,
features=lowered(model.hr_out_features),
lat_lon=chunk.hr_lat_lon,
row_inds=chunk.row_inds,
col_inds=chunk.col_inds,
times=chunk.hr_times,
out_file=chunk.out_file,
meta_data=meta,
Expand Down
14 changes: 14 additions & 0 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ForwardPassChunk:
hr_lat_lon: Union[np.ndarray, da.core.Array]
hr_times: pd.DatetimeIndex
gids: Union[np.ndarray, da.core.Array]
row_inds: np.ndarray
col_inds: np.ndarray
out_file: str
pad_width: tuple[tuple, tuple, tuple]
index: int
Expand Down Expand Up @@ -458,6 +460,16 @@ def hr_lat_lon(self):
)
return OutputHandler.get_lat_lon(lr_lat_lon, shape)

@cached_property
def grid_inds(self):
    """Row and column index arrays spanning the full high-resolution
    grid. These indices are used to collect spatially contiguous data
    when stitching output chunks back together.

    Returns
    -------
    tuple of np.ndarray
        ``(row_inds, col_inds)`` covering the full hr grid extent.
    """
    # hr_lat_lon is (rows, cols, 2); drop the trailing coord axis
    n_rows, n_cols = self.hr_lat_lon.shape[:-1]
    return np.arange(n_rows), np.arange(n_cols)

@cached_property
def out_files(self):
"""Get list of output file names for each file chunk forward pass."""
Expand Down Expand Up @@ -580,6 +592,8 @@ def init_chunk(self, chunk_index=0):
hr_times=OutputHandler.get_times(
lr_times, self.t_enhance * len(lr_times)
),
row_inds=self.grid_inds[0][hr_slice[0]],
col_inds=self.grid_inds[1][hr_slice[1]],
gids=self.gids[hr_slice[:2]],
out_file=self.out_files[chunk_index],
pad_width=self.fwp_slicer.extra_padding[chunk_index],
Expand Down
22 changes: 15 additions & 7 deletions sup3r/postprocessing/collectors/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rex.utilities.loggers import init_logger
from scipy.spatial import KDTree

from sup3r.preprocessing import Loader
from sup3r.preprocessing.utilities import _mem_check
from sup3r.utilities.utilities import get_dset_attrs, get_tmp_file
from sup3r.writers import RexOutputs
Expand Down Expand Up @@ -93,13 +94,15 @@ def get_coordinate_indices(cls, target_meta, full_meta, threshold=1e-4):
threshold : float
Threshold distance for finding target coordinates within full meta
"""
ll2 = np.vstack(
(full_meta.latitude.values, full_meta.longitude.values)
).T
ll2 = np.vstack((
full_meta.latitude.values,
full_meta.longitude.values,
)).T
tree = KDTree(ll2)
targets = np.vstack(
(target_meta.latitude.values, target_meta.longitude.values)
).T
targets = np.vstack((
target_meta.latitude.values,
target_meta.longitude.values,
)).T
_, indices = tree.query(targets, distance_upper_bound=threshold)
indices = indices[indices < len(full_meta)]
return indices
Expand Down Expand Up @@ -723,7 +726,7 @@ def collect(
cls,
file_paths,
out_file,
features,
features='all',
max_workers=None,
log_level=None,
log_file=None,
Expand Down Expand Up @@ -796,6 +799,11 @@ def collect(
os.makedirs(os.path.dirname(out_file), exist_ok=True)

collector = cls(file_paths)
features = (
Loader(collector.flist[0]).features
if features == 'all'
else features
)
logger.info(
'Collecting %s files to %s', len(collector.flist), out_file
)
Expand Down
51 changes: 7 additions & 44 deletions sup3r/postprocessing/collectors/nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,10 @@ def collect(
log_file=None,
overwrite=True,
res_kwargs=None,
cacher_kwargs=None,
is_regular_grid=True,
cacher_kwargs=None
):
"""Collect data files from a dir to one output file.

TODO: For a regular grid (lat values are constant across lon and vice
versa) collecting lat / lon chunks is supported. For curvilinear grids
only collection of chunks that are split by latitude are supported.
This should be generalized to allow for any spatial chunking and any
dimension. I think this would require a new file naming scheme with a
spatial index for both latitude and longitude or checking each chunk
to see how they are split.

Filename requirements:
- Should end with ".nc"

Expand Down Expand Up @@ -71,13 +62,6 @@ def collect(
Dictionary of kwargs to pass to xarray.open_mfdataset.
cacher_kwargs : dict | None
Dictionary of kwargs to pass to Cacher._write_single.
is_regular_grid : bool
Whether the data is on a regular grid. If True then spatial chunks
can be combined across both latitude and longitude. If False then
spatial chunks must all have the same longitude values to be
combined. If you need completely general chunk collection then
you should write chunks to `h5` files and use
:class:`sup3r.postprocessing.collectors.h5.CollectorH5`.
"""
logger.info(f'Initializing collection for file_paths={file_paths}')

Expand All @@ -97,33 +81,12 @@ def collect(
logger.info(f'overwrite=True, removing {out_file}.')
os.remove(out_file)

spatial_chunks = collector.group_spatial_chunks()
spatial_chunks = collector.group_spatial_chunks(res_kwargs=res_kwargs)

if not os.path.exists(out_file):
res_kwargs = res_kwargs or {
'combine': 'nested',
'concat_dim': Dimension.TIME,
}
for s_idx, sfiles in spatial_chunks.items():
schunk = Loader(sfiles, res_kwargs=res_kwargs)
spatial_chunks[s_idx] = schunk

# Set lat / lon as 1D arrays if regular grid and get the
# xr.Dataset _ds
if is_regular_grid:
spatial_chunks = {
s_idx: schunk.set_regular_grid()._ds
for s_idx, schunk in spatial_chunks.items()
}
out = xr.combine_by_coords(
spatial_chunks.values(), combine_attrs='override'
)

else:
out = xr.concat(
[sc._ds for sc in spatial_chunks.values()],
dim=Dimension.SOUTH_NORTH,
)
dsets = list(spatial_chunks.values())
dsets = [ds.reset_coords(Dimension.coords_2d()) for ds in dsets]
out = xr.combine_by_coords(dsets, combine_attrs='override')

cacher_kwargs = cacher_kwargs or {}
Cacher._write_single(
Expand All @@ -135,13 +98,13 @@ def collect(

logger.info('Finished file collection.')

def group_spatial_chunks(self):
def group_spatial_chunks(self, res_kwargs=None):
"""Group same spatial chunks together so each entry has same spatial
footprint but different times"""
chunks = {}
for file in self.flist:
_, s_idx = self.get_chunk_indices(file)
chunks[s_idx] = [*chunks.get(s_idx, []), file]
for k, v in chunks.items():
chunks[k] = sorted(v)
chunks[k] = Loader(sorted(v), res_kwargs=res_kwargs)
return chunks
14 changes: 14 additions & 0 deletions sup3r/writers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ def _write_output(
invert_uv=False,
nn_fill=False,
max_workers=None,
row_inds=None,
col_inds=None,
gids=None,
):
"""Write output to file with specified times and lats/lons"""
Expand All @@ -584,6 +586,8 @@ def write_output(
nn_fill=False,
max_workers=None,
gids=None,
row_inds=None,
col_inds=None,
):
"""Write forward pass output to file

Expand Down Expand Up @@ -615,6 +619,14 @@ def write_output(
gids : list
List of coordinate indices used to label each lat lon pair and to
help with spatial chunk data collection
row_inds : np.ndarray
Array of row indices for the full high resolution grid. This is
used to collect spatially contiguous data for stitching output
chunks back together.
col_inds : np.ndarray
Array of column indices for the full high resolution grid. This is
used to collect spatially contiguous data for stitching output
chunks back together.
"""
lat_lon = cls.get_lat_lon(low_res_lat_lon, data.shape[:2])
times = cls.get_times(low_res_times, data.shape[-2])
Expand All @@ -628,5 +640,7 @@ def write_output(
invert_uv=invert_uv,
nn_fill=nn_fill,
max_workers=max_workers,
row_inds=row_inds,
col_inds=col_inds,
gids=gids,
)
18 changes: 18 additions & 0 deletions sup3r/writers/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def _write_output(
invert_uv=False,
nn_fill=False,
max_workers=None,
row_inds=None,
col_inds=None,
gids=None,
):
"""Write forward pass output to H5 file
Expand Down Expand Up @@ -57,6 +59,14 @@ def _write_output(
neighbour or cap to limits
max_workers : int | None
Max workers to use for inverse transform.
row_inds : np.ndarray
Array of row indices for the full high resolution grid. This is
used to collect spatially contiguous data for stitching output
chunks back together.
col_inds : np.ndarray
Array of column indices for the full high resolution grid. This is
used to collect spatially contiguous data for stitching output
chunks back together.
gids : list
List of coordinate indices used to label each lat lon pair and to
help with spatial chunk data collection
Expand Down Expand Up @@ -84,8 +94,16 @@ def _write_output(
if gids is not None
else np.arange(np.prod(lat_lon.shape[:-1]))
)
row_inds = (
row_inds if row_inds is not None else np.arange(lat_lon.shape[0])
)
col_inds = (
col_inds if col_inds is not None else np.arange(lat_lon.shape[1])
)
meta = pd.DataFrame({
'gid': gids.flatten(),
'row_ind': row_inds.flatten(),
'col_ind': col_inds.flatten(),
'latitude': lat_lon[..., 0].flatten(),
'longitude': lat_lon[..., 1].flatten(),
})
Expand Down
28 changes: 24 additions & 4 deletions sup3r/writers/nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _write_output(
max_workers=None,
invert_uv=False,
nn_fill=False,
row_inds=None,
col_inds=None,
gids=None,
):
"""Write forward pass output to NETCDF file
Expand Down Expand Up @@ -59,6 +61,14 @@ def _write_output(
nn_fill : bool
Whether to fill data outside of limits with nearest neighbour or
cap to limits
row_inds : np.ndarray
Array of row indices for the full high resolution grid. This is
used to help with spatial chunk data collection and should be
included if the output data is spatially chunked.
col_inds : np.ndarray
Array of column indices for the full high resolution grid. This is
used to help with spatial chunk data collection and should be
included if the output data is spatially chunked.
gids : list
List of coordinate indices used to label each lat lon pair and to
help with spatial chunk data collection
Expand All @@ -72,20 +82,30 @@ def _write_output(
max_workers=max_workers,
)

coords = {
data_vars = {
Dimension.TIME: times,
Dimension.LATITUDE: (Dimension.dims_2d(), lat_lon[:, :, 0]),
Dimension.LONGITUDE: (Dimension.dims_2d(), lat_lon[:, :, 1]),
}
data_vars = {}
if gids is not None:
data_vars = {'gids': (Dimension.dims_2d(), gids)}
data_vars['gids'] = (Dimension.dims_2d(), gids)
if row_inds is not None and col_inds is not None:
for dim, inds in zip(Dimension.dims_2d(), [row_inds, col_inds]):
data_vars[dim] = (dim, inds)
for i, f in enumerate(features):
data_vars[f] = (
(Dimension.TIME, *Dimension.dims_2d()),
np.transpose(data[..., i], axes=(2, 0, 1)).astype(np.float32),
)

if all(d in data_vars for d in Dimension.dims_2d()):
coords = {dim: data_vars.pop(dim) for dim in Dimension.dims_2d()}
else:
coords = {
coord: data_vars.pop(coord) for coord in Dimension.coords_2d()
}
coords[Dimension.TIME] = data_vars.pop(Dimension.TIME)

attrs = meta_data or {}
now = dt.now(datetime.timezone.utc).isoformat()
attrs['date_modified'] = now
Expand All @@ -95,6 +115,6 @@ def _write_output(
Cacher._write_single(
out_file=out_file,
data=ds,
features=features,
features=list(data_vars.keys()),
max_workers=max_workers,
)
3 changes: 1 addition & 2 deletions tests/output/test_output_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def test_general_nc_collect():
out[-1],
)

CollectorNC.collect(out_files, fp_out, features=features,
is_regular_grid=True)
CollectorNC.collect(out_files, fp_out, features=features)

with Loader(fp_out) as res:
assert np.array_equal(hr_times, res.time_index.values)
Expand Down
Loading