diff --git a/reproject/common.py b/reproject/common.py
index a3e46c495..814dc93ca 100644
--- a/reproject/common.py
+++ b/reproject/common.py
@@ -60,6 +60,7 @@ def _reproject_dispatcher(
     shape_out,
     wcs_out,
     block_size=None,
+    non_reprojected_dims=None,
     array_out=None,
     return_footprint=True,
     output_footprint=None,
@@ -93,6 +94,11 @@ def _reproject_dispatcher(
         the block size automatically determined. If ``block_size`` is not
         specified or set to `None`, the reprojection will not be carried
         out in blocks.
+    non_reprojected_dims : tuple of int, optional
+        Dimensions that should not be reprojected, and for which a 1-to-1
+        mapping between input and output pixel space is assumed instead.
+        By default, these are any leading extra dimensions present when the
+        input WCS has fewer dimensions than the input data.
     array_out : `~numpy.ndarray`, optional
         An array in which to store the reprojected data. This can be any
         numpy array including a memory map, which may be helpful when dealing with
@@ -143,6 +149,19 @@ def _reproject_dispatcher(
     if reproject_func_kwargs is None:
         reproject_func_kwargs = {}
 
+    # For now, we are quite restrictive in what non_reprojected_dims can be,
+    # but the keyword is designed so that more use cases could be supported
+    # in future. Currently it has to be a tuple whose elements increase
+    # sequentially from zero, e.g. (0,), (0, 1), or (0, 1, 2).
+
+    if non_reprojected_dims is None:
+        n_dim_reproject = min(wcs_in.low_level_wcs.pixel_n_dim, wcs_out.low_level_wcs.pixel_n_dim)
+    else:
+        if non_reprojected_dims == tuple(range(len(non_reprojected_dims))):
+            n_dim_reproject = len(shape_out) - len(non_reprojected_dims)
+        else:
+            raise ValueError("non_reprojected_dims should be a tuple with values increasing sequentially from zero")
+
     # We set up a global temporary directory since this will be used e.g. to
     # store memory mapped Numpy arrays and zarr arrays.
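As context for reviewers, here is a minimal usage sketch of the new keyword as it is wired up in `reproject_interp` later in this diff. All data, shapes, and WCS values below are made up for illustration:

```python
import numpy as np
from astropy.wcs import WCS

from reproject import reproject_interp

# Hypothetical cube whose leading axis (e.g. time) carries no WCS information
cube = np.random.random((4, 64, 64))

# 2D celestial WCS describing only the last two axes
wcs_in = WCS(naxis=2)
wcs_in.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs_in.wcs.crval = [10.0, 20.0]
wcs_in.wcs.cdelt = [-0.01, 0.01]
wcs_in.wcs.crpix = [32.0, 32.0]

# Same projection on a shifted output grid
wcs_out = wcs_in.deepcopy()
wcs_out.wcs.crpix = [30.0, 34.0]

# (0,) marks the leading axis as a 1-to-1 passthrough; per the validation
# above it must count up from zero, e.g. (0,) or (0, 1). Leaving it as None
# gives the same result here, since the WCS has fewer dimensions than the data.
result, footprint = reproject_interp(
    (cube, wcs_in),
    wcs_out,
    shape_out=cube.shape,
    non_reprojected_dims=(0,),
)
```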
@@ -206,9 +225,41 @@ def _reproject_dispatcher(
     # shape_out will be the full size of the output array as this is updated
     # in parse_output_projection, even if shape_out was originally passed in as
     # the shape of a single image.
-    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)
-    logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used")
+    broadcasting = n_dim_reproject < len(shape_out)
+
+    logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used, reprojecting last {n_dim_reproject} axes")
+
+    # Output shape should match input shape for any non-reprojected dimensions
+    # TODO: check for shape_out not matching shape_in along broadcasted dimensions
+
+    shape_in = array_in.shape
+    shape_out = tuple(shape_out)
+
+    if shape_out[:-n_dim_reproject] != shape_in[:-n_dim_reproject]:
+        raise ValueError("Input shape should match output shape for non-reprojected dimensions")
+
+    # block_size=None means the reprojection is not carried out in blocks, so
+    # only validate and normalize the block size when one was given.
+    if block_size is not None:
+        if len(block_size) > len(shape_out):
+            raise ValueError(
+                f"block_size {block_size} cannot have more elements "
+                f"than the dimensionality of the output ({len(shape_out)})"
+            )
+
+        if len(block_size) != n_dim_reproject and len(block_size) != len(shape_out):
+            raise ValueError(
+                f"block_size {block_size} should have either {n_dim_reproject} or {len(shape_out)} elements"
+            )
+
+        if len(block_size) == n_dim_reproject:
+            block_size = (-1,) * (len(shape_out) - n_dim_reproject) + tuple(block_size)
+
+        # Replace -1 entries by the full output length along that dimension
+        block_size = tuple(block_size[i] if block_size[i] != -1 else shape_out[i] for i in range(len(block_size)))
+
+    # TODO: re-implement block_size auto
 
     # Check block size and determine whether block size indicates we should
     # parallelize over broadcasted dimension. The logic is as follows: if
@@ -220,32 +271,15 @@
     # don't make any assumptions for now and assume a single chunk in the
     # missing dimensions.
     broadcasted_parallelization = False
-    if broadcasting and block_size is not None and block_size != "auto":
-        if len(block_size) == len(shape_out):
-            if (
-                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
-                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
-            ):
-                broadcasted_parallelization = True
-                block_size = (
-                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
-                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
-                )
-            else:
-                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
-                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
-                        raise ValueError(
-                            "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
-                        )
-        elif len(block_size) < len(shape_out):
-            block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
-        else:
-            raise ValueError(
-                f"block_size {len(block_size)} cannot have more elements "
-                f"than the dimensionality of the output ({len(shape_out)})"
-            )
-
-    # TODO: check for shape_out not matching shape_in along broadcasted dimensions
+    if broadcasting and block_size is not None:
+        if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]:
+            # TODO: maybe error if block_size was given in full and is wrong
+            broadcasted_parallelization = True
+            block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[-n_dim_reproject:]
+        elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]:
+            raise ValueError(
+                "block_size should match shape_out along either the reprojected or the non-reprojected dimensions"
+            )
 
     logger.info(
         f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along "
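The normalization above is easiest to see with concrete values. This standalone snippet mirrors the padding and -1 substitution steps (shapes made up):

```python
shape_out = (4, 256, 256)  # one non-reprojected axis plus two reprojected axes
n_dim_reproject = 2

# Case 1: block_size given for the reprojected axes only
block_size = (128, 128)
block_size = (-1,) * (len(shape_out) - n_dim_reproject) + tuple(block_size)
assert block_size == (-1, 128, 128)

# -1 entries are replaced by the full output length along that dimension
block_size = tuple(
    block_size[i] if block_size[i] != -1 else shape_out[i] for i in range(len(block_size))
)
assert block_size == (4, 128, 128)

# Here block_size[-2:] != shape_out[-2:], so the dispatcher takes the plain
# chunked path. Case 2: (256, 256) would normalize to (4, 256, 256), matching
# shape_out on the reprojected axes, so the dispatcher would instead
# parallelize over the leading axis with one (1, 256, 256) block per plane.
```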
@@ -255,8 +289,6 @@
     if output_footprint is None and return_footprint:
         output_footprint = np.zeros(shape_out, dtype=float)
 
-    shape_in = array_in.shape
-
     def reproject_single_block(a, array_or_path, block_info=None):
         if (
@@ -270,6 +302,8 @@ def reproject_single_block(a, array_or_path, block_info=None):
         if isinstance(array_or_path, str) and array_or_path == "from-dict":
             array_or_path = dask_arrays["array"]
 
+        shape_out = block_info[None]["chunk-shape"][1:]
+
         # The WCS class from astropy is not thread-safe, see e.g.
         # https://github.com/astropy/astropy/issues/16244
         # https://github.com/astropy/astropy/issues/16245
@@ -281,16 +315,38 @@ def reproject_single_block(a, array_or_path, block_info=None):
         wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in
         wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out
 
-        slices = [
-            slice(*x) for x in block_info[None]["array-location"][-wcs_out_cp.pixel_n_dim :]
-        ]
+        slices_in = []
+        slices_out = []
+        for idx in range(len(shape_out)):
+            interval = block_info[None]["array-location"][idx + 1]  # skip the extra leading dim
+            if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject:
+                if interval[1] - interval[0] != 1:
+                    raise RuntimeError(f"Expected a chunk of width 1 along dimension {idx} (got {interval[1] - interval[0]})")
+                slices_in.append(interval[0])
+                slices_out.append(interval[0])
+            else:
+                slices_in.append(slice(None))
+                slices_out.append(slice(*interval))
+
+        slices_in = slices_in[-wcs_in.pixel_n_dim :]
+        slices_out = slices_out[-wcs_out.pixel_n_dim :]
+
+        if broadcasted_parallelization:
+            if isinstance(wcs_in_cp, BaseHighLevelWCS):
+                low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in)
+            else:
+                low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in)
 
-        if isinstance(wcs_out, BaseHighLevelWCS):
-            low_level_wcs = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices)
+            wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in)
         else:
-            low_level_wcs = SlicedLowLevelWCS(wcs_out_cp, slices=slices)
+            wcs_in_sub = wcs_in_cp
 
-        wcs_out_sub = HighLevelWCSWrapper(low_level_wcs)
+        if isinstance(wcs_out_cp, BaseHighLevelWCS):
+            low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out)
+        else:
+            low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out)
+
+        wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out)
 
         if isinstance(array_or_path, tuple):
             array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r")
@@ -302,11 +358,9 @@ def reproject_single_block(a, array_or_path, block_info=None):
         if array_or_path is None:
             raise RuntimeError("array_or_path is not set")
 
-        shape_out = block_info[None]["chunk-shape"][1:]
-
         array, footprint = reproject_func(
             array_in,
-            wcs_in_cp,
+            wcs_in_sub,
             wcs_out_sub,
             shape_out=shape_out,
             array_out=np.zeros(shape_out),
@@ -319,12 +373,14 @@
     array_out_dask = da.empty(shape_out, chunks=block_size)
 
     if isinstance(array_in, da.core.Array):
-        if array_in.chunksize != block_size:
-            logger.info(
-                f"Rechunking input dask array as chunks ({array_in.chunksize}) "
-                "do not match block size ({block_size})"
-            )
-            array_in = array_in.rechunk(block_size)
+        pass
+        # FIXME: Should take into account -1s here
+        # if array_in.chunksize != block_size:
+        #     logger.info(
+        #         f"Rechunking input dask array as chunks ({array_in.chunksize}) "
+        #         f"do not match block size ({block_size})"
+        #     )
+        #     array_in = array_in.rechunk(block_size)
     else:
 
         class ArrayWrapper:
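The per-block logic above leans on `SlicedLowLevelWCS` semantics: an integer entry drops a pixel dimension, a slice entry keeps it with the offset folded in, and the slices are given in array (numpy) order, matching how `slices_in`/`slices_out` are built. A standalone sketch with a hypothetical 3D WCS:

```python
from astropy.wcs import WCS
from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS

# Hypothetical cube WCS; array axes are (wave, dec, ra), so ctype, which is
# given in FITS/pixel order, is the reverse
wcs = WCS(naxis=3)
wcs.wcs.ctype = ["RA---TAN", "DEC--TAN", "WAVE"]

# Fix wave index 0, keep rows 0..127 and all columns, as a block worker would
low_level = SlicedLowLevelWCS(wcs.low_level_wcs, slices=[0, slice(0, 128), slice(None)])
sub_wcs = HighLevelWCSWrapper(low_level)

print(low_level.pixel_n_dim)  # 2: the wavelength axis has been dropped
```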
diff --git a/reproject/interpolation/core.py b/reproject/interpolation/core.py
index d3021992c..105aac881 100644
--- a/reproject/interpolation/core.py
+++ b/reproject/interpolation/core.py
@@ -10,7 +10,9 @@
 def _validate_wcs(wcs_in, wcs_out, shape_in, shape_out):
     if wcs_in.low_level_wcs.pixel_n_dim != wcs_out.low_level_wcs.pixel_n_dim:
-        raise ValueError("Number of dimensions in input and output WCS should match")
+        raise ValueError(
+            "Number of dimensions in input and output WCS should match "
+            f"(got {wcs_in.low_level_wcs.pixel_n_dim} and {wcs_out.low_level_wcs.pixel_n_dim})"
+        )
     elif len(shape_out) < wcs_out.low_level_wcs.pixel_n_dim:
         raise ValueError("Too few dimensions in shape_out")
     elif len(shape_in) < wcs_in.low_level_wcs.pixel_n_dim:
diff --git a/reproject/interpolation/high_level.py b/reproject/interpolation/high_level.py
index ed3ab89af..c7f3758e2 100644
--- a/reproject/interpolation/high_level.py
+++ b/reproject/interpolation/high_level.py
@@ -25,6 +25,7 @@ def reproject_interp(
     output_footprint=None,
     return_footprint=True,
     block_size=None,
+    non_reprojected_dims=None,
     parallel=False,
     return_type=None,
     dask_method=None,
@@ -152,6 +153,7 @@ def reproject_interp(
         array_out=output_array,
         parallel=parallel,
         block_size=block_size,
+        non_reprojected_dims=non_reprojected_dims,
         return_footprint=return_footprint,
         output_footprint=output_footprint,
         reproject_func_kwargs=dict(
diff --git a/reproject/mosaicking/coadd.py b/reproject/mosaicking/coadd.py
index 0db6b3b08..461cf7dd5 100644
--- a/reproject/mosaicking/coadd.py
+++ b/reproject/mosaicking/coadd.py
@@ -243,10 +243,15 @@ def reproject_and_coadd(
         # convex in the output projection), and transforming every edge pixel,
         # which provides a lot of redundant information.
 
-        edges = sample_array_edges(
-            array_in.shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=11
-        )[::-1]
-        edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1]
+        # TODO: ignore non-reprojected dims here and slice WCS
+
+        try:
+            edges = sample_array_edges(
+                array_in.shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=11
+            )[::-1]
+            edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1]
+        except Exception:
+            edges_out = np.array([np.nan])  # NaN edges cause this input to be skipped below
 
         # Determine the cutout parameters
@@ -257,7 +262,7 @@
         ndim_out = len(shape_out)
 
         # Determine how many extra broadcasted dimensions are present
-        n_broadcasted = len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim
+        n_broadcasted = len(shape_out) - wcs_out.low_level_wcs.pixel_n_dim
 
         skip_data = False
         if np.any(np.isnan(edges_out)):
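The coadd change boils down to a skip-on-failure pattern: if sampling and transforming the input edges fails, `edges_out` is set to NaN so the existing NaN check below marks the input as having no usable overlap instead of crashing. A standalone sketch (the import location of `sample_array_edges` is an assumption; WCS values are made up):

```python
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel

from reproject.utils import sample_array_edges  # assumed helper location

wcs_in = WCS(naxis=2)
wcs_in.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs_in.wcs.crval = [10.0, 20.0]
wcs_in.wcs.cdelt = [-0.01, 0.01]
wcs_in.wcs.crpix = [32.0, 32.0]
wcs_out = wcs_in.deepcopy()

try:
    edges = sample_array_edges((64, 64), n_samples=11)[::-1]
    edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1]
except Exception:
    edges_out = np.array([np.nan])

# Any NaN in edges_out marks the input as contributing nothing
skip_data = bool(np.any(np.isnan(edges_out)))
```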