Skip to content

Commit c0b5977

Browse files
ez96The Meridian Authors
authored andcommitted
cast knots to list if user tries to pass in result from get_knot_info()
PiperOrigin-RevId: 859909723
1 parent dcf117e commit c0b5977

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

meridian/model/spec.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414

1515
"""Defines model specification parameters for Meridian."""
1616

17-
from collections.abc import Mapping
17+
from collections.abc import Collection, Mapping
1818
import dataclasses
1919
from typing import Sequence
2020
import warnings
21-
2221
from meridian import constants
2322
from meridian.model import prior_distribution
2423
import numpy as np
@@ -166,17 +165,17 @@ class ModelSpec:
166165
given non_media treatments channel). If `None`, the minimum value is used
167166
as baseline for each non-media treatments channel. This attribute is used
168167
as the default value for the corresponding argument to `Analyzer` methods.
169-
knots: An optional integer or list of integers indicating the knots used to
170-
estimate time effects. When `knots` is a list of integers, the knot
171-
locations are provided by that list. Zero corresponds to a knot at the
172-
first time period, one corresponds to a knot at the second time period,
173-
..., and `(n_times - 1)` corresponds to a knot at the last time period).
174-
Typically, we recommend including knots at `0` and `(n_times - 1)`, but
175-
this is not required. When `knots` is an integer, then there are knots
176-
with locations equally spaced across the time periods, (including knots at
177-
zero and `(n_times - 1)`. When `knots` is` 1`, there is a single common
178-
regression coefficient used for all time periods. If `knots` is set to
179-
`None`, then the numbers of knots used is equal to the number of time
168+
knots: An optional integer or collection of integers indicating the knots
169+
used to estimate time effects. When `knots` is a collection of integers,
170+
the knot locations are provided by that list. Zero corresponds to a knot
171+
at the first time period, one corresponds to a knot at the second time
172+
period, ..., and `(n_times - 1)` corresponds to a knot at the last time
173+
period). Typically, we recommend including knots at `0` and `(n_times -
174+
1)`, but this is not required. When `knots` is an integer, then there are
175+
knots with locations equally spaced across the time periods, (including
176+
knots at zero and `(n_times - 1)`. When `knots` is` 1`, there is a single
177+
common regression coefficient used for all time periods. If `knots` is set
178+
to `None`, then the numbers of knots used is equal to the number of time
180179
periods in the case of a geo model. This is equivalent to each time period
181180
having its own regression coefficient. If `knots` is set to `None` in the
182181
case of a national model, then the number of knots used is `1`. Default:
@@ -235,7 +234,7 @@ class ModelSpec:
235234
constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION
236235
)
237236
non_media_baseline_values: Sequence[float | str] | None = None
238-
knots: int | list[int] | None = None
237+
knots: int | Collection[int] | None = None
239238
baseline_geo: int | str | None = None
240239
holdout_id: np.ndarray | None = None
241240
control_population_scaling_id: np.ndarray | None = None
@@ -321,6 +320,12 @@ def __post_init__(self):
321320
prior_type_name="rf_prior_type",
322321
)
323322

323+
if isinstance(self.knots, Collection):
324+
knots_list = list(self.knots)
325+
if not all(isinstance(x, (int, np.integer)) for x in knots_list):
326+
raise ValueError("`knots` must be a sequence of integers.")
327+
object.__setattr__(self, "knots", [int(x) for x in knots_list])
328+
324329
# Validate knots.
325330
if isinstance(self.knots, list) and not self.knots:
326331
raise ValueError("The `knots` parameter cannot be an empty list.")
@@ -330,6 +335,10 @@ def __post_init__(self):
330335
raise ValueError(
331336
"The `knots` parameter cannot be set when `enable_aks` is True."
332337
)
338+
if not (self.knots is None or isinstance(self.knots, (int, list))):
339+
raise ValueError(
340+
f"Unsupported type for `knots` parameter: {type(self.knots)}."
341+
)
333342

334343
@property
335344
def effective_media_prior_type(self) -> str:

meridian/model/spec_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,42 @@ def test_init_warns_with_only_paid_media_prior_type(self):
414414
with self.assertWarnsRegex(UserWarning, warning_message):
415415
spec.ModelSpec(paid_media_prior_type="roi")
416416

417+
@parameterized.named_parameters(
418+
("ndarray", np.array([2, 5, 8], dtype=int), [2, 5, 8]),
419+
("tuple", (2, 5, 8), [2, 5, 8]),
420+
("set", {2, 5, 8}, [2, 5, 8]),
421+
("list", [2, 5, 8], [2, 5, 8]),
422+
("dict_keys", {2: "a", 5: "b", 8: "c"}, [2, 5, 8]),
423+
)
424+
def test_spec_inits_knots_with_collection_converts_to_list(
425+
self, knots_input, expected
426+
):
427+
"""Tests that passing any collection for knots converts it to a list[int]."""
428+
model_spec = spec.ModelSpec(knots=knots_input)
429+
430+
self.assertIsInstance(model_spec.knots, list)
431+
self.assertCountEqual(model_spec.knots, expected)
432+
433+
@parameterized.named_parameters(
434+
("strings_list", ["a", "b"]),
435+
("strings_tuple", ("a", "b")),
436+
("floats_list", [1.1, 2.2]),
437+
("mixed_tuple", (1, "a")),
438+
)
439+
def test_spec_inits_knots_with_non_integers_fails(self, knots_input):
440+
"""Tests that collections containing non-integers raise ValueError."""
441+
with self.assertRaisesRegex(
442+
ValueError, "`knots` must be a sequence of integers"
443+
):
444+
spec.ModelSpec(knots=knots_input)
445+
446+
def test_spec_inits_knots_with_unsupported_type_fails(self):
447+
"""Tests that passing an unsupported type (e.g. dict) raises ValueError."""
448+
with self.assertRaisesRegex(
449+
ValueError, "Unsupported type for `knots` parameter"
450+
):
451+
spec.ModelSpec(knots=3.5) # pytype: disable=wrong-arg-types
452+
417453

418454
if __name__ == "__main__":
419455
absltest.main()

0 commit comments

Comments
 (0)