Skip to content

Commit b8b8bcf

Browse files
committed
Updated obs model to use separate obs frac for onshore and offshore regions.
1 parent 1fec038 commit b8b8bcf

2 files changed

Lines changed: 62 additions & 32 deletions

File tree

sup3r/models/with_obs.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,51 @@
1414

1515

1616
class Sup3rGanWithObs(Sup3rGan):
17-
"""Sup3r GAN model which includes mid network observation fixing. This
17+
"""Sup3r GAN model which includes mid network observation fusion. This
1818
model is useful for when production runs will be over a domain for which
1919
observation data is available."""
2020

21-
def __init__(self, *args, obs_frac=None, loss_obs_weight=None, **kwargs):
21+
def __init__(
22+
self,
23+
*args,
24+
onshore_obs_frac=None,
25+
offshore_obs_frac=None,
26+
loss_obs_weight=None,
27+
**kwargs,
28+
):
2229
"""
2330
Initialize the Sup3rGanWithObs model.
2431
2532
Parameters
2633
----------
2734
args : list
2835
Positional args for ``Sup3rGan`` parent class.
29-
obs_frac : dict
30-
Fraction of the batch that should be "fixed" with observations.
31-
Should include ``spatial`` key and optionally ``time`` key if this
32-
is a spatiotemporal model. The values should correspond roughly to
33-
the fraction of the production domain for which observations are
34-
available (spatial) and the fraction of the full time period that
35-
these cover. For each batch a spatial frac will be selected by
36-
uniformly selecting from the range ``(0, obs_frac['spatial'])``
36+
onshore_obs_frac : dict
37+
Fraction of the batch that should be treated as onshore
38+
observations. Should include ``spatial`` key and optionally
39+
``time`` key if this is a spatiotemporal model. The values should
40+
correspond roughly to the fraction of the production domain for
41+
which onshore observations are available (spatial) and the fraction
42+
of the full time period that these cover. For each batch a spatial
43+
frac will be selected by uniformly selecting from the range ``(0,
44+
onshore_obs_frac['spatial'])``
45+
offshore_obs_frac : dict
46+
Same as ``onshore_obs_frac`` but for offshore observations.
47+
Offshore observations are frequently sparser than onshore
48+
observations.
3749
loss_obs_weight : float
3850
Value used to weight observation locations in extra content loss
3951
term. e.g. The new content loss will include ``loss_obs_weight *
4052
MAE(hi_res_gen[~obs_mask], hi_res_true[~obs_mask])``
4153
kwargs : dict
4254
Keyword arguments for the ``Sup3rGan`` parent class.
4355
"""
44-
self.obs_frac = {} if obs_frac is None else obs_frac
56+
self.onshore_obs_frac = (
57+
{} if onshore_obs_frac is None else onshore_obs_frac
58+
)
59+
self.offshore_obs_frac = (
60+
{} if offshore_obs_frac is None else offshore_obs_frac
61+
)
4562
self.loss_obs_weight = loss_obs_weight
4663
super().__init__(*args, **kwargs)
4764

@@ -51,12 +68,12 @@ def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
5168
locations."""
5269

5370
hr_true = [
54-
hi_res_true[..., self.hr_out_features.index(f)]
71+
hi_res_true[..., self.hr_out_features.index(f.replace('_obs', ''))]
5572
for f in self.obs_features
5673
]
5774
hr_true = tf.stack(hr_true, axis=-1)
5875
hr_gen = [
59-
hi_res_gen[..., self.hr_out_features.index(f)]
76+
hi_res_gen[..., self.hr_out_features.index(f.replace('_obs', ''))]
6077
for f in self.obs_features
6178
]
6279
hr_gen = tf.stack(hr_gen, axis=-1)
@@ -65,23 +82,15 @@ def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
6582
loss_non_obs = MeanAbsoluteError()(hr_true[obs_mask], hr_gen[obs_mask])
6683
return loss_obs, loss_non_obs
6784

68-
def _get_obs_mask(self, hi_res, spatial_frac=None, time_frac=None):
69-
"""Define observation mask for the current batch. This is done
70-
with a spatial mask and a temporal mask since often observation data
71-
might be very sparse spatially but cover most of the full time period
72-
for those locations."""
73-
spatial_frac = (
74-
self.obs_frac['spatial'] if spatial_frac is None else spatial_frac
75-
)
85+
def _get_obs_mask(self, hi_res, spatial_frac, time_frac=None):
86+
"""Get observation mask for a given spatial and temporal obs
87+
fraction."""
7688
obs_mask = RANDOM_GENERATOR.choice(
7789
[True, False],
7890
size=hi_res.shape[1:3],
7991
p=[1 - spatial_frac, spatial_frac],
8092
)
8193
if self.is_5d:
82-
time_frac = (
83-
self.obs_frac['time'] if time_frac is None else time_frac
84-
)
8594
sp_mask = obs_mask.copy()
8695
obs_mask = RANDOM_GENERATOR.choice(
8796
[True, False],
@@ -91,6 +100,30 @@ def _get_obs_mask(self, hi_res, spatial_frac=None, time_frac=None):
91100
obs_mask[sp_mask] = True
92101
return np.repeat(obs_mask[None, ...], hi_res.shape[0], axis=0)
93102

103+
def get_obs_mask(self, hi_res):
104+
"""Define observation mask for the current batch. This is done
105+
with a spatial mask and a temporal mask since often observation data
106+
might be very sparse spatially but cover most of the full time period
107+
for those locations. This is also divided between onshore and offshore
108+
regions."""
109+
on_sf = RANDOM_GENERATOR.uniform(
110+
low=0, high=self.onshore_obs_frac['spatial']
111+
)
112+
on_tf = self.onshore_obs_frac.get('time', None)
113+
off_tf = self.offshore_obs_frac.get('time', None)
114+
obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf)
115+
if 'topography' in self.hr_exo_features and self.offshore_obs_frac:
116+
topo_idx = len(self.hr_out_features) + self.hr_exo_features.index(
117+
'topography'
118+
)
119+
topo = hi_res[..., topo_idx]
120+
off_sf = RANDOM_GENERATOR.uniform(
121+
low=0, high=self.offshore_obs_frac['spatial']
122+
)
123+
offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf)
124+
obs_mask = tf.where(topo > 0, obs_mask, offshore_mask)
125+
return obs_mask
126+
94127
@property
95128
def model_params(self):
96129
"""
@@ -101,19 +134,16 @@ def model_params(self):
101134
dict
102135
"""
103136
params = super().model_params
104-
params['obs_frac'] = self.obs_frac
137+
params['onshore_obs_frac'] = self.onshore_obs_frac
138+
params['offshore_obs_frac'] = self.offshore_obs_frac
105139
params['loss_obs_weight'] = self.loss_obs_weight
106140
return params
107141

108142
def get_hr_exo_input(self, hi_res_true):
109143
"""Mask high res data to act as sparse observation data. Add this to
110144
the standard high res exo input"""
111145
exo_data = super().get_hr_exo_input(hi_res_true)
112-
spatial_frac = RANDOM_GENERATOR.uniform(
113-
low=0, high=self.obs_frac['spatial']
114-
)
115-
time_frac = self.obs_frac.get('time', None)
116-
obs_mask = self._get_obs_mask(hi_res_true, spatial_frac, time_frac)
146+
obs_mask = self.get_obs_mask(hi_res_true)
117147
for feature in self.obs_features:
118148
# obs_features can include a _obs suffix to avoid name conflict
119149
# with fully gridded exo features

tests/training/test_train_conditioned_obs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def test_fixed_wind_obs(gen_config_with_concat_masked):
4747
model = Sup3rGanWithObs(
4848
gen_config_with_concat_masked(),
4949
pytest.S_FP_DISC,
50-
obs_frac={'spatial': 0.1},
50+
onshore_obs_frac={'spatial': 0.1},
5151
loss_obs_weight=0.1,
5252
learning_rate=1e-4,
5353
)
54-
test_mask = model._get_obs_mask(np.zeros((1, 20, 20, 1, 1)))
54+
test_mask = model.get_obs_mask(np.zeros((1, 20, 20, 1, 1)))
5555
frac = 1 - test_mask.sum() / test_mask.size
5656
assert np.abs(0.1 - frac) < test_mask.size / (2 * np.sqrt(test_mask.size))
5757
assert model.obs_features == ['u_10m', 'v_10m']

0 commit comments

Comments
 (0)