1414
1515
1616class Sup3rGanWithObs (Sup3rGan ):
17- """Sup3r GAN model which includes mid network observation fixing . This
17+ """Sup3r GAN model which includes mid network observation fusion . This
1818 model is useful for when production runs will be over a domain for which
1919 observation data is available."""
2020
21- def __init__ (self , * args , obs_frac = None , loss_obs_weight = None , ** kwargs ):
21+ def __init__ (
22+ self ,
23+ * args ,
24+ onshore_obs_frac = None ,
25+ offshore_obs_frac = None ,
26+ loss_obs_weight = None ,
27+ ** kwargs ,
28+ ):
2229 """
2330 Initialize the Sup3rGanWithObs model.
2431
2532 Parameters
2633 ----------
2734 args : list
2835 Positional args for ``Sup3rGan`` parent class.
29- obs_frac : dict
30- Fraction of the batch that should be "fixed" with observations.
31- Should include ``spatial`` key and optionally ``time`` key if this
32- is a spatiotemporal model. The values should correspond roughly to
33- the fraction of the production domain for which observations are
34- available (spatial) and the fraction of the full time period that
35- these cover. For each batch a spatial frac will be selected by
36- uniformly selecting from the range ``(0, obs_frac['spatial'])``
36+ onshore_obs_frac : dict
37+ Fraction of the batch that should be treated as onshore
38+ observations. Should include ``spatial`` key and optionally
39+ ``time`` key if this is a spatiotemporal model. The values should
40+ correspond roughly to the fraction of the production domain for
41+ which onshore observations are available (spatial) and the fraction
42+ of the full time period that these cover. For each batch a spatial
43+ frac will be selected by uniformly selecting from the range ``(0,
44+ onshore_obs_frac['spatial'])``
45+ offshore_obs_frac : dict
46+ Same as ``onshore_obs_frac`` but for offshore observations.
47+ Offshore observations are frequently sparser than onshore
48+ observations.
3749 loss_obs_weight : float
3850 Value used to weight observation locations in extra content loss
3951 term. e.g. The new content loss will include ``loss_obs_weight *
4052 MAE(hi_res_gen[~obs_mask], hi_res_true[~obs_mask])``
4153 kwargs : dict
4254 Keyword arguments for the ``Sup3rGan`` parent class.
4355 """
44- self .obs_frac = {} if obs_frac is None else obs_frac
56+ self .onshore_obs_frac = (
57+ {} if onshore_obs_frac is None else onshore_obs_frac
58+ )
59+ self .offshore_obs_frac = (
60+ {} if offshore_obs_frac is None else offshore_obs_frac
61+ )
4562 self .loss_obs_weight = loss_obs_weight
4663 super ().__init__ (* args , ** kwargs )
4764
@@ -51,12 +68,12 @@ def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
5168 locations."""
5269
5370 hr_true = [
54- hi_res_true [..., self .hr_out_features .index (f )]
71+ hi_res_true [..., self .hr_out_features .index (f . replace ( '_obs' , '' ) )]
5572 for f in self .obs_features
5673 ]
5774 hr_true = tf .stack (hr_true , axis = - 1 )
5875 hr_gen = [
59- hi_res_gen [..., self .hr_out_features .index (f )]
76+ hi_res_gen [..., self .hr_out_features .index (f . replace ( '_obs' , '' ) )]
6077 for f in self .obs_features
6178 ]
6279 hr_gen = tf .stack (hr_gen , axis = - 1 )
@@ -65,23 +82,15 @@ def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask):
6582 loss_non_obs = MeanAbsoluteError ()(hr_true [obs_mask ], hr_gen [obs_mask ])
6683 return loss_obs , loss_non_obs
6784
68- def _get_obs_mask (self , hi_res , spatial_frac = None , time_frac = None ):
69- """Define observation mask for the current batch. This is done
70- with a spatial mask and a temporal mask since often observation data
71- might be very sparse spatially but cover most of the full time period
72- for those locations."""
73- spatial_frac = (
74- self .obs_frac ['spatial' ] if spatial_frac is None else spatial_frac
75- )
85+ def _get_obs_mask (self , hi_res , spatial_frac , time_frac = None ):
86+ """Get observation mask for a given spatial and temporal obs
87+ fraction."""
7688 obs_mask = RANDOM_GENERATOR .choice (
7789 [True , False ],
7890 size = hi_res .shape [1 :3 ],
7991 p = [1 - spatial_frac , spatial_frac ],
8092 )
8193 if self .is_5d :
82- time_frac = (
83- self .obs_frac ['time' ] if time_frac is None else time_frac
84- )
8594 sp_mask = obs_mask .copy ()
8695 obs_mask = RANDOM_GENERATOR .choice (
8796 [True , False ],
@@ -91,6 +100,30 @@ def _get_obs_mask(self, hi_res, spatial_frac=None, time_frac=None):
91100 obs_mask [sp_mask ] = True
92101 return np .repeat (obs_mask [None , ...], hi_res .shape [0 ], axis = 0 )
93102
103+ def get_obs_mask (self , hi_res ):
104+ """Define observation mask for the current batch. This is done
105+ with a spatial mask and a temporal mask since often observation data
106+ might be very sparse spatially but cover most of the full time period
107+ for those locations. This is also divided between onshore and offshore
108+ regions."""
109+ on_sf = RANDOM_GENERATOR .uniform (
110+ low = 0 , high = self .onshore_obs_frac ['spatial' ]
111+ )
112+ on_tf = self .onshore_obs_frac .get ('time' , None )
113+ off_tf = self .offshore_obs_frac .get ('time' , None )
114+ obs_mask = self ._get_obs_mask (hi_res , on_sf , on_tf )
115+ if 'topography' in self .hr_exo_features and self .offshore_obs_frac :
116+ topo_idx = len (self .hr_out_features ) + self .hr_exo_features .index (
117+ 'topography'
118+ )
119+ topo = hi_res [..., topo_idx ]
120+ off_sf = RANDOM_GENERATOR .uniform (
121+ low = 0 , high = self .offshore_obs_frac ['spatial' ]
122+ )
123+ offshore_mask = self ._get_obs_mask (hi_res , off_sf , off_tf )
124+ obs_mask = tf .where (topo > 0 , obs_mask , offshore_mask )
125+ return obs_mask
126+
94127 @property
95128 def model_params (self ):
96129 """
@@ -101,19 +134,16 @@ def model_params(self):
101134 dict
102135 """
103136 params = super ().model_params
104- params ['obs_frac' ] = self .obs_frac
137+ params ['onshore_obs_frac' ] = self .onshore_obs_frac
138+ params ['offshore_obs_frac' ] = self .offshore_obs_frac
105139 params ['loss_obs_weight' ] = self .loss_obs_weight
106140 return params
107141
108142 def get_hr_exo_input (self , hi_res_true ):
109143 """Mask high res data to act as sparse observation data. Add this to
110144 the standard high res exo input"""
111145 exo_data = super ().get_hr_exo_input (hi_res_true )
112- spatial_frac = RANDOM_GENERATOR .uniform (
113- low = 0 , high = self .obs_frac ['spatial' ]
114- )
115- time_frac = self .obs_frac .get ('time' , None )
116- obs_mask = self ._get_obs_mask (hi_res_true , spatial_frac , time_frac )
146+ obs_mask = self .get_obs_mask (hi_res_true )
117147 for feature in self .obs_features :
118148 # obs_features can include a _obs suffix to avoid name conflict
119149 # with fully gridded exo features
0 commit comments