Skip to content

Commit b737092

Browse files
committed
refactor: improve gradient handling in pcgrad function by optimizing gradient flattening and shuffling
1 parent bb3df0b commit b737092

File tree

3 files changed

+736
-10
lines changed

3 files changed

+736
-10
lines changed

sup3r/preprocessing/samplers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,11 @@ def _get_proxy_obs(self, hi_res):
454454
"""
455455
obs_mask = self._get_full_obs_mask(hi_res)
456456
obs = hi_res[..., self.obs_features_ind].copy()
457-
obs[obs_mask[..., : obs.shape[-1]]] = np.nan
458457
if self.perturbation_scale > 0:
459458
stdev = np.nanstd(obs, axis=(0, 1, 2, 3), keepdims=True)
460459
noise = np.random.uniform(-stdev, stdev)
461460
obs += self.perturbation_scale * noise
461+
obs[obs_mask[..., : obs.shape[-1]]] = np.nan
462462
return obs
463463

464464
def _append_obs_features(self, samples):

0 commit comments

Comments
 (0)