Skip to content

Commit 2c8d084

Browse files
authored
Merge pull request #265 from NREL/gb/bc_optm
Gb/bc optm
2 parents cbb7c8e + fd4179c commit 2c8d084

File tree

4 files changed

+33
-16
lines changed

4 files changed

+33
-16
lines changed

sup3r/bias/bias_transforms.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,10 @@ def _apply_qdm(
575575
if no_trend
576576
else np.reshape(bias_fut_params, (-1, bias_fut_params.shape[-1]))
577577
)
578+
578579
# The distributions at this point, after selected the respective
579580
# time window with `window_idx`, are 3D (space, space, N-params)
580581
# Collapse 3D (space, space, N) into 2D (space**2, N)
581-
582582
QDM = QuantileDeltaMapping(
583583
params_oh=np.reshape(base_params, (-1, base_params.shape[-1])),
584584
params_mh=np.reshape(bias_params, (-1, bias_params.shape[-1])),
@@ -594,8 +594,12 @@ def _apply_qdm(
594594
# input 3D shape (spatial, spatial, temporal)
595595
# QDM expects input arr with shape (time, space)
596596
tmp = np.reshape(subset.data, (-1, subset.shape[-1])).T
597+
597598
# Apply QDM correction
599+
logger.info(f'Applying QDM to data with shape {tmp.shape}...')
598600
tmp = QDM(tmp, max_workers=max_workers)
601+
logger.info(f'Finished QDM on data shape {tmp.shape}!')
602+
599603
# Reorgnize array back from (time, space)
600604
# to (spatial, spatial, temporal)
601605
return np.reshape(tmp.T, subset.shape)
@@ -751,9 +755,13 @@ def local_qdm_bc(
751755
)
752756

753757
cfg = params['cfg']
754-
base_params = params['base']
755-
bias_params = params['bias']
756-
bias_fut_params = params.get('bias_fut', None)
758+
759+
# params as dask arrays slows down QDM by several orders of magnitude
760+
base_params = np.array(params['base'])
761+
bias_params = np.array(params['bias'])
762+
bias_fut_params = None
763+
if 'bias_fut' in params:
764+
bias_fut_params = np.array(params['bias_fut'])
757765

758766
if lr_padded_slice is not None:
759767
spatial_slice = (lr_padded_slice[0], lr_padded_slice[1])

sup3r/bias/utilities.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
get_date_range_kwargs,
1616
)
1717

18+
1819
logger = logging.getLogger(__name__)
1920

2021

@@ -249,8 +250,10 @@ def bias_correct_feature(
249250
Data corrected by the bias_correct_method ready for input to the
250251
forward pass through the generative model.
251252
"""
253+
252254
time_slice = _parse_time_slice(time_slice)
253255
data = input_handler[source_feature][..., time_slice]
256+
254257
lat_lon = input_handler.lat_lon
255258
if bc_method is not None:
256259
bc_method = getattr(sup3r.bias.bias_transforms, bc_method)
@@ -260,15 +263,15 @@ def bias_correct_feature(
260263
if 'date_range_kwargs' in signature(bc_method).parameters:
261264
ti = input_handler.time_index[time_slice]
262265
feature_kwargs['date_range_kwargs'] = get_date_range_kwargs(ti)
263-
if (
264-
'lr_padded_slice' in signature(bc_method).parameters
265-
and 'lr_padded_slice' not in feature_kwargs
266-
):
266+
267+
use_lrps = 'lr_padded_slice' in signature(bc_method).parameters
268+
need_lrps = 'lr_padded_slice' not in feature_kwargs
269+
if use_lrps and need_lrps:
267270
feature_kwargs['lr_padded_slice'] = None
268-
if (
269-
'temporal_avg' in signature(bc_method).parameters
270-
and 'temporal_avg' not in feature_kwargs
271-
):
271+
272+
use_tavg = 'temporal_avg' in signature(bc_method).parameters
273+
need_tavg = 'temporal_avg' not in feature_kwargs
274+
if use_tavg and need_tavg:
272275
msg = (
273276
'The kwarg "temporal_avg" was not provided in the bias '
274277
'correction kwargs but is present in the bias '
@@ -306,6 +309,7 @@ def bias_correct_features(
306309
"""
307310

308311
time_slice = _parse_time_slice(time_slice)
312+
309313
for feat in features:
310314
try:
311315
input_handler[feat][..., time_slice] = bias_correct_feature(
@@ -324,4 +328,5 @@ def bias_correct_features(
324328
)
325329
logger.exception(msg)
326330
raise RuntimeError(msg) from e
331+
327332
return input_handler

sup3r/pipeline/forward_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from sup3r.utilities import ModuleName
2626
from sup3r.utilities.cli import BaseCLI
27+
from sup3r.utilities.utilities import Timer
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -52,6 +53,7 @@ def __init__(self, strategy, node_index=0):
5253
node_index : int
5354
Index of node used to run forward pass
5455
"""
56+
self.timer = Timer()
5557
self.strategy = strategy
5658
self.model = get_model(strategy.model_class, strategy.model_kwargs)
5759
self.node_index = node_index
@@ -230,7 +232,8 @@ def run_generator(
230232
temp = cls._reshape_data_chunk(model, data_chunk, exo_data)
231233
data_chunk, exo_data, i_lr_t, i_lr_s = temp
232234
try:
233-
hi_res = model.generate(data_chunk, exogenous_data=exo_data)
235+
fun = Timer()(model.generate, log=True)
236+
hi_res = fun(data_chunk, exogenous_data=exo_data)
234237
except Exception as e:
235238
msg = 'Forward pass failed on chunk with shape {}.'.format(
236239
data_chunk.shape

sup3r/pipeline/strategy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,10 @@ def prep_chunk_data(self, chunk_index=0):
504504
f'Bias correcting data for chunk_index={chunk_index}, '
505505
f'with shape={input_data.shape}'
506506
)
507-
input_data = self.timer(
508-
bias_correct_features, log=True, call_id=chunk_index
509-
)(
507+
fun = self.timer(
508+
bias_correct_features, log=True, call_id=chunk_index,
509+
)
510+
input_data = fun(
510511
features=list(self.bias_correct_kwargs),
511512
input_handler=input_data,
512513
bc_method=self.bias_correct_method,

0 commit comments

Comments
 (0)