Replace stacking gradient search with resample_blocks variant
mraspaud committed Oct 22, 2024
1 parent 9becb0d, commit 3c07255
Showing 4 changed files with 81 additions and 509 deletions.
243 changes: 10 additions & 233 deletions pyresample/gradient/__init__.py
@@ -53,244 +53,19 @@ def GradientSearchResampler(source_geo_def, target_geo_def):

def create_gradient_search_resampler(source_geo_def, target_geo_def):
"""Create a gradient search resampler."""
-    if isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition):
+    if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or
+            (isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))):
         return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def)
-    elif isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition):
-        return StackingGradientSearchResampler(source_geo_def, target_geo_def)
     raise NotImplementedError


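Editor's note: the new dispatch means swath-to-area requests no longer get the stacking implementation. A minimal sketch of the resulting behavior; the area parameters below are invented for illustration only:

```python
from pyresample.geometry import AreaDefinition
from pyresample.gradient import create_gradient_search_resampler

# Invented source/target areas, for illustration only.
src_area = AreaDefinition("geos", "geos area", "geos",
                          "+proj=geos +h=35785831 +a=6378169 +b=6356583.8",
                          3712, 3712,
                          (-5570248.7, -5567248.1, 5567248.1, 5570248.7))
dst_area = AreaDefinition("ll", "global lat/lon grid", "ll", "EPSG:4326",
                          360, 180, (-180.0, -90.0, 180.0, 90.0))

# Area->area (and, after this commit, also swath->area) yields the
# resample_blocks-based variant.
resampler = create_gradient_search_resampler(src_area, dst_area)
print(type(resampler).__name__)  # ResampleBlocksGradientSearchResampler
```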
 @da.as_gufunc(signature='(),()->(),()')
 def transform(x_coords, y_coords, src_prj=None, dst_prj=None):
     """Calculate projection coordinates."""
-    transformer = pyproj.Transformer.from_crs(src_prj, dst_prj)
+    transformer = pyproj.Transformer.from_crs(src_prj, dst_prj, always_xy=True)
     return transformer.transform(x_coords, y_coords)

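Editor's note: the always_xy=True addition is the substantive fix in the transform helper. By default pyproj follows each CRS's declared axis order (latitude first for EPSG:4326), while always_xy=True forces (x, y), that is (lon, lat), on both input and output. A standalone illustration:

```python
import pyproj

lon, lat = 25.0, 60.0

# Default: EPSG:4326 declares (lat, lon) axis order, so latitude goes first.
t_default = pyproj.Transformer.from_crs("EPSG:4326", "EPSG:3857")
x1, y1 = t_default.transform(lat, lon)

# always_xy=True: arguments and results are always (x, y) == (lon, lat).
t_xy = pyproj.Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)
x2, y2 = t_xy.transform(lon, lat)

assert (x1, y1) == (x2, y2)
```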

class StackingGradientSearchResampler(BaseResampler):
"""Resample using gradient search based bilinear interpolation, using stacking for dask processing."""

def __init__(self, source_geo_def, target_geo_def):
"""Init GradientResampler."""
super().__init__(source_geo_def, target_geo_def)
import warnings
warnings.warn("You are using the Gradient Search Resampler, which is still EXPERIMENTAL.", stacklevel=2)
self.use_input_coords = None
self._src_dst_filtered = False
self.prj = None
self.src_x = None
self.src_y = None
self.src_slices = None
self.dst_x = None
self.dst_y = None
self.dst_slices = None
self.src_gradient_xl = None
self.src_gradient_xp = None
self.src_gradient_yl = None
self.src_gradient_yp = None
self.dst_polys = {}
self.dst_mosaic_locations = None
self.coverage_status = None

def _get_projection_coordinates(self, datachunks):
"""Get projection coordinates."""
if self.use_input_coords is None:
try:
self.src_x, self.src_y = self.source_geo_def.get_proj_coords(
chunks=datachunks)
src_crs = self.source_geo_def.crs
self.use_input_coords = True
except AttributeError:
self.src_x, self.src_y = self.source_geo_def.get_lonlats(
chunks=datachunks)
src_crs = pyproj.CRS.from_string("+proj=longlat")
self.use_input_coords = False
try:
self.dst_x, self.dst_y = self.target_geo_def.get_proj_coords(
chunks=CHUNK_SIZE)
dst_crs = self.target_geo_def.crs
except AttributeError as err:
if self.use_input_coords is False:
raise NotImplementedError('Cannot resample lon/lat to lon/lat with gradient search.') from err
self.dst_x, self.dst_y = self.target_geo_def.get_lonlats(
chunks=CHUNK_SIZE)
dst_crs = pyproj.CRS.from_string("+proj=longlat")
if self.use_input_coords:
self.dst_x, self.dst_y = transform(
self.dst_x, self.dst_y,
src_prj=dst_crs, dst_prj=src_crs)
self.prj = pyproj.Proj(self.source_geo_def.crs)
else:
self.src_x, self.src_y = transform(
self.src_x, self.src_y,
src_prj=src_crs, dst_prj=dst_crs)
self.prj = pyproj.Proj(self.target_geo_def.crs)

def _get_prj_poly(self, geo_def):
# - None if out of Earth Disk
# - False if SwathDefinition
if isinstance(geo_def, SwathDefinition):
return False
try:
poly = get_polygon(self.prj, geo_def)
except (NotImplementedError, ValueError): # out-of-earth disk area or no valid projected boundary coordinates
poly = None
return poly

def _get_src_poly(self, src_y_start, src_y_end, src_x_start, src_x_end):
"""Get bounding polygon for source chunk."""
geo_def = self.source_geo_def[src_y_start:src_y_end,
src_x_start:src_x_end]
return self._get_prj_poly(geo_def)

def _get_dst_poly(self, idx,
dst_x_start, dst_x_end,
dst_y_start, dst_y_end):
"""Get target chunk polygon."""
dst_poly = self.dst_polys.get(idx, None)
if dst_poly is None:
geo_def = self.target_geo_def[dst_y_start:dst_y_end,
dst_x_start:dst_x_end]
dst_poly = self._get_prj_poly(geo_def)
self.dst_polys[idx] = dst_poly
return dst_poly

def get_chunk_mappings(self):
"""Map source and target chunks together if they overlap."""
src_y_chunks, src_x_chunks = self.src_x.chunks
dst_y_chunks, dst_x_chunks = self.dst_x.chunks

coverage_status = []
src_slices, dst_slices = [], []
dst_mosaic_locations = []

src_x_start = 0
for src_x_step in src_x_chunks:
src_x_end = src_x_start + src_x_step
src_y_start = 0
for src_y_step in src_y_chunks:
src_y_end = src_y_start + src_y_step
# Get source chunk polygon
src_poly = self._get_src_poly(src_y_start, src_y_end,
src_x_start, src_x_end)

dst_x_start = 0
for x_step_number, dst_x_step in enumerate(dst_x_chunks):
dst_x_end = dst_x_start + dst_x_step
dst_y_start = 0
for y_step_number, dst_y_step in enumerate(dst_y_chunks):
dst_y_end = dst_y_start + dst_y_step
# Get destination chunk polygon
dst_poly = self._get_dst_poly((x_step_number, y_step_number),
dst_x_start, dst_x_end,
dst_y_start, dst_y_end)

covers = check_overlap(src_poly, dst_poly)

coverage_status.append(covers)
src_slices.append((src_y_start, src_y_end,
src_x_start, src_x_end))
dst_slices.append((dst_y_start, dst_y_end,
dst_x_start, dst_x_end))
dst_mosaic_locations.append((x_step_number, y_step_number))

dst_y_start = dst_y_end
dst_x_start = dst_x_end
src_y_start = src_y_end
src_x_start = src_x_end

self.src_slices = src_slices
self.dst_slices = dst_slices
self.dst_mosaic_locations = dst_mosaic_locations
self.coverage_status = coverage_status

def _filter_data(self, data, is_src=True, add_dim=False):
"""Filter unused chunks from the given array."""
if add_dim:
if data.ndim not in [2, 3]:
raise NotImplementedError('Gradient search resampling only '
'supports 2D or 3D arrays.')
if data.ndim == 2:
data = data[np.newaxis, :, :]

data_out = []
for i, covers in enumerate(self.coverage_status):
if covers:
if is_src:
y_start, y_end, x_start, x_end = self.src_slices[i]
else:
y_start, y_end, x_start, x_end = self.dst_slices[i]
try:
val = data[:, y_start:y_end, x_start:x_end]
except IndexError:
val = data[y_start:y_end, x_start:x_end]
else:
val = None
data_out.append(val)

return data_out

def _get_gradients(self):
"""Get gradients in X and Y directions."""
self.src_gradient_xl, self.src_gradient_xp = np.gradient(
self.src_x, axis=[0, 1])
self.src_gradient_yl, self.src_gradient_yp = np.gradient(
self.src_y, axis=[0, 1])

def _filter_src_dst(self):
"""Filter source and target chunks."""
self.src_x = self._filter_data(self.src_x)
self.src_y = self._filter_data(self.src_y)
self.src_gradient_yl = self._filter_data(self.src_gradient_yl)
self.src_gradient_yp = self._filter_data(self.src_gradient_yp)
self.src_gradient_xl = self._filter_data(self.src_gradient_xl)
self.src_gradient_xp = self._filter_data(self.src_gradient_xp)
self.dst_x = self._filter_data(self.dst_x, is_src=False)
self.dst_y = self._filter_data(self.dst_y, is_src=False)
self._src_dst_filtered = True

def compute(self, data, fill_value=None, **kwargs):
"""Resample the given data using gradient search algorithm."""
if 'bands' in data.dims:
datachunks = data.sel(bands=data.coords['bands'][0]).chunks
else:
datachunks = data.chunks
data_dims = data.dims
data_coords = data.coords

self._get_projection_coordinates(datachunks)

if self.src_gradient_xl is None:
self._get_gradients()
if self.coverage_status is None:
self.get_chunk_mappings()
if not self._src_dst_filtered:
self._filter_src_dst()

data = self._filter_data(data.data, add_dim=True)

res = parallel_gradient_search(data,
self.src_x, self.src_y,
self.dst_x, self.dst_y,
self.src_gradient_xl,
self.src_gradient_xp,
self.src_gradient_yl,
self.src_gradient_yp,
self.dst_mosaic_locations,
self.dst_slices,
**kwargs)

coords = _fill_in_coords(self.target_geo_def, data_coords, data_dims)

if fill_value is not None:
res = da.where(np.isnan(res), fill_value, res)
if res.ndim > len(data_dims):
res = res.squeeze()

res = xr.DataArray(res, dims=data_dims, coords=coords)
return res


def check_overlap(src_poly, dst_poly):
"""Check if the two polygons overlap."""
# swath definition case
@@ -491,8 +266,10 @@ def __init__(self, source_geo_def, target_geo_def):
"""Init GradientResampler."""
if isinstance(target_geo_def, SwathDefinition):
raise NotImplementedError("Cannot resample to a SwathDefinition.")
if isinstance(source_geo_def, SwathDefinition):
source_geo_def.lons = source_geo_def.lons.persist()
source_geo_def.lats = source_geo_def.lats.persist()
super().__init__(source_geo_def, target_geo_def)
logger.debug("/!\\ Instantiating an experimental GradientSearch resampler /!\\")
self.indices_xy = None

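Editor's note: persisting the swath lon/lats up front plausibly avoids recomputing the coordinate arrays once per target chunk. A rough standalone illustration of dask's persist semantics; the array here is made up:

```python
import dask.array as da

# Stand-in for expensively computed swath longitudes.
lons = da.random.uniform(-180, 180, size=(4000, 4000), chunks=1000)

# Without persist, each compute() below re-executes the producing graph;
# persist() evaluates it once and keeps the concrete chunks in memory.
lons = lons.persist()

print(float(lons.min().compute()))
print(float(lons.max().compute()))  # reuses the persisted chunks
```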
def precompute(self, **kwargs):
Expand Down Expand Up @@ -590,11 +367,11 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
def _get_coordinates_in_same_projection(source_area, target_area):
try:
src_x, src_y = source_area.get_proj_coords()
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
except AttributeError as err:
raise NotImplementedError("Cannot resample from Swath for now.") from err

lons, lats = source_area.get_lonlats()
src_x, src_y = da.compute(lons, lats)
try:
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
except AttributeError as err:
raise NotImplementedError("Cannot resample to Swath for now.") from err
@@ -618,7 +395,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
     res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] +
            (1 - weight_l) * weight_p * data[..., l_start, p_end] +
            weight_l * (1 - weight_p) * data[..., l_end, p_start] +
-           weight_l * weight_p * data[..., l_end, p_end])
+           weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype)
     res = np.where(mask, fill_value, res)
     return res

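Editor's note: the only change to block_bilinear_interpolator is the trailing astype call. The fractional weights are floating-point (float64) arrays, so without the cast, float32 input would silently be promoted to float64. A small sketch of the promotion, with made-up weights:

```python
import numpy as np

data = np.ones((2, 2), dtype=np.float32)
weight_l = np.array([0.25])  # fractional line weight (float64)
weight_p = np.array([0.75])  # fractional pixel weight (float64)

res = ((1 - weight_l) * (1 - weight_p) * data[..., 0, 0] +
       (1 - weight_l) * weight_p * data[..., 0, 1] +
       weight_l * (1 - weight_p) * data[..., 1, 0] +
       weight_l * weight_p * data[..., 1, 1])

print(res.dtype)                     # float64: the weights promoted the result
print(res.astype(data.dtype).dtype)  # float32: the added cast keeps the input dtype
```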
8 changes: 4 additions & 4 deletions pyresample/gradient/_gradient_search.pyx
@@ -80,10 +80,10 @@ cdef inline void bil(const data_type[:, :, :] data, int l0, int p0, float_index
     p_b = min(p0 + 1, pmax)
     w_p = dp
     for i in range(z_size):
-        res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
-                  (1 - w_l) * w_p * data[i, l_a, p_b] +
-                  w_l * (1 - w_p) * data[i, l_b, p_a] +
-                  w_l * w_p * data[i, l_b, p_b])
+        res[i] = <data_type>((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
+                             (1 - w_l) * w_p * data[i, l_a, p_b] +
+                             w_l * (1 - w_p) * data[i, l_b, p_a] +
+                             w_l * w_p * data[i, l_b, p_b])


@cython.boundscheck(False)
7 changes: 3 additions & 4 deletions pyresample/resampler.py
@@ -214,6 +214,9 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
         fill_value: Desired value for any invalid values in the output array
         kwargs: any other keyword arguments that will be passed on to func.
+    Returns:
+        A dask array, chunked as dst_area, containing the resampled data.
     Principle of operations:
         Resample_blocks works by iterating over chunks on the dst_area domain. For each chunk, the corresponding slice
@@ -235,10 +238,6 @@
"""
if dst_area == src_area:
raise ValueError("Source and destination areas are identical."
" Should you be running `map_blocks` instead of `resample_blocks`?")

name = _create_dask_name(name, func,
src_area, src_arrays,
dst_area, dst_arrays,
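Editor's note: a hedged sketch of the calling convention this docstring describes, using the index-computing function visible earlier in this diff. src_area and dst_area stand for any source and target AreaDefinitions, and the chunk_size and dtype keywords are assumptions about the part of the signature truncated above:

```python
from pyresample.gradient import gradient_resampler_indices
from pyresample.resampler import resample_blocks

# One call of gradient_resampler_indices per dst_area chunk. No source arrays
# are needed to compute the indices, hence the empty list. The result is a
# dask array chunked like dst_area, with two planes (x and y fractional indices).
indices_xy = resample_blocks(gradient_resampler_indices,
                             src_area, [], dst_area,
                             chunk_size=(2, 1024, 1024), dtype=float)
```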
(diff for the fourth changed file not loaded)