Skip to content

Commit ee4869a

Browse files
committed
Simplify API by setting .selectors via .loc in __post_init__
1 parent 2be39aa commit ee4869a

File tree

2 files changed

+16
-62
lines changed

2 files changed

+16
-62
lines changed

src/skillmodels/constraints.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,18 @@ class FixedConstraintWithValue(om.FixedConstraint):
4141
"""Value to enforce on the parameter."""
4242

4343
def __post_init__(self) -> None:
44-
"""Validate that `loc` and `value` are not None."""
44+
"""Validate that `loc` and `value` are not None and derive `selector`."""
4545
if self.loc is None:
4646
msg = "loc must not be None"
4747
raise TypeError(msg)
4848
if self.value is None:
4949
msg = "value must not be None"
5050
raise TypeError(msg)
51+
object.__setattr__(
52+
self,
53+
"selector",
54+
functools.partial(select_by_loc, loc=self.loc),
55+
)
5156

5257

5358
def get_constraints(
@@ -178,22 +183,10 @@ def _get_normalization_constraints(
178183
for period in periods:
179184
for meas, normval in normalizations[factor].loadings[period].items():
180185
loc = ("loadings", period, meas, factor)
181-
constraints.append(
182-
FixedConstraintWithValue(
183-
selector=functools.partial(select_by_loc, loc=loc),
184-
loc=loc,
185-
value=normval,
186-
)
187-
)
186+
constraints.append(FixedConstraintWithValue(loc=loc, value=normval))
188187
for meas, normval in normalizations[factor].intercepts[period].items():
189188
loc = ("controls", period, meas, "constant")
190-
constraints.append(
191-
FixedConstraintWithValue(
192-
selector=functools.partial(select_by_loc, loc=loc),
193-
loc=loc,
194-
value=normval,
195-
)
196-
)
189+
constraints.append(FixedConstraintWithValue(loc=loc, value=normval))
197190

198191
return constraints
199192

@@ -205,11 +198,7 @@ def _get_mixture_weights_constraints(
205198
loc = "mixture_weights"
206199
if n_mixtures == 1:
207200
return [
208-
FixedConstraintWithValue(
209-
selector=functools.partial(select_by_loc, loc=loc),
210-
loc=loc,
211-
value=1.0,
212-
),
201+
FixedConstraintWithValue(loc=loc, value=1.0),
213202
]
214203
return [
215204
om.ProbabilityConstraint(selector=functools.partial(select_by_loc, loc=loc))
@@ -277,11 +266,7 @@ def _get_constant_factors_constraints(
277266
for aug_period in labels.aug_periods[:-1]:
278267
loc = ("shock_sds", aug_period, factor, "-")
279268
constraints.append(
280-
FixedConstraintWithValue(
281-
selector=functools.partial(select_by_loc, loc=loc),
282-
loc=loc,
283-
value=0.0,
284-
),
269+
FixedConstraintWithValue(loc=loc, value=0.0),
285270
)
286271
return constraints
287272

@@ -371,11 +356,7 @@ def _get_anchoring_constraints( # noqa: C901
371356
if locs:
372357
loc = tuple(locs)
373358
constraints.append(
374-
FixedConstraintWithValue(
375-
selector=functools.partial(select_by_loc, loc=loc),
376-
loc=loc,
377-
value=0,
378-
),
359+
FixedConstraintWithValue(loc=loc, value=0),
379360
)
380361

381362
if not anchoring_info.free_controls:
@@ -386,11 +367,7 @@ def _get_anchoring_constraints( # noqa: C901
386367
if ind_tups:
387368
loc = tuple(ind_tups)
388369
constraints.append(
389-
FixedConstraintWithValue(
390-
selector=functools.partial(select_by_loc, loc=loc),
391-
loc=loc,
392-
value=0,
393-
),
370+
FixedConstraintWithValue(loc=loc, value=0),
394371
)
395372

396373
if not anchoring_info.free_loadings:
@@ -404,11 +381,7 @@ def _get_anchoring_constraints( # noqa: C901
404381
if ind_tups:
405382
loc = tuple(ind_tups)
406383
constraints.append(
407-
FixedConstraintWithValue(
408-
selector=functools.partial(select_by_loc, loc=loc),
409-
loc=loc,
410-
value=1,
411-
),
384+
FixedConstraintWithValue(loc=loc, value=1),
412385
)
413386

414387
return constraints
@@ -467,7 +440,6 @@ def _get_constraints_for_augmented_periods(
467440
loc = ("shock_sds", aug_period, factor, "-")
468441
constraints.append(
469442
FixedConstraintWithValue(
470-
selector=functools.partial(select_by_loc, loc=loc),
471443
loc=loc,
472444
value=endogenous_factors_info.bounds_distance,
473445
)

src/skillmodels/transition_functions.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,7 @@ def identity_constraints_linear(
6666
for regressor in params_linear(all_factors):
6767
val = 1.0 if factor == regressor else 0.0
6868
loc = ("transition", aug_period, factor, regressor)
69-
constraints.append(
70-
FixedConstraintWithValue(
71-
selector=functools.partial(select_by_loc, loc=loc),
72-
loc=loc,
73-
value=val,
74-
)
75-
)
69+
constraints.append(FixedConstraintWithValue(loc=loc, value=val))
7670
return constraints
7771

7872

@@ -120,13 +114,7 @@ def identity_constraints_translog(
120114
for regressor in params_translog(all_factors):
121115
val = 1.0 if factor == regressor else 0.0
122116
loc = ("transition", aug_period, factor, regressor)
123-
constraints.append(
124-
FixedConstraintWithValue(
125-
selector=functools.partial(select_by_loc, loc=loc),
126-
loc=loc,
127-
value=val,
128-
)
129-
)
117+
constraints.append(FixedConstraintWithValue(loc=loc, value=val))
130118
return constraints
131119

132120

@@ -242,13 +230,7 @@ def identity_constraints_linear_and_squares(
242230
for regressor in params_linear_and_squares(all_factors):
243231
val = 1.0 if factor == regressor else 0.0
244232
loc = ("transition", aug_period, factor, regressor)
245-
constraints.append(
246-
FixedConstraintWithValue(
247-
selector=functools.partial(select_by_loc, loc=loc),
248-
loc=loc,
249-
value=val,
250-
)
251-
)
233+
constraints.append(FixedConstraintWithValue(loc=loc, value=val))
252234
return constraints
253235

254236

0 commit comments

Comments
 (0)