1414
1515"""Defines model specification parameters for Meridian."""
1616
17- from collections .abc import Mapping
17+ from collections .abc import Collection , Mapping
1818import dataclasses
1919from typing import Sequence
2020import warnings
21-
2221from meridian import constants
2322from meridian .model import prior_distribution
2423import numpy as np
@@ -166,17 +165,17 @@ class ModelSpec:
166165 given non_media treatments channel). If `None`, the minimum value is used
167166 as baseline for each non-media treatments channel. This attribute is used
168167 as the default value for the corresponding argument to `Analyzer` methods.
169- knots: An optional integer or list of integers indicating the knots used to
170- estimate time effects. When `knots` is a list of integers, the knot
171- locations are provided by that list. Zero corresponds to a knot at the
172- first time period, one corresponds to a knot at the second time period,
173- ..., and `(n_times - 1)` corresponds to a knot at the last time period).
174- Typically, we recommend including knots at `0` and `(n_times - 1)`, but
175- this is not required. When `knots` is an integer, then there are knots
176- with locations equally spaced across the time periods, (including knots at
177- zero and `(n_times - 1)`. When `knots` is` 1`, there is a single common
178- regression coefficient used for all time periods. If `knots` is set to
179- `None`, then the numbers of knots used is equal to the number of time
168+ knots: An optional integer or collection of integers indicating the knots
169+ used to estimate time effects. When `knots` is a collection of integers,
170+ the knot locations are provided by that list. Zero corresponds to a knot
171+ at the first time period, one corresponds to a knot at the second time
172+ period, ..., and `(n_times - 1)` corresponds to a knot at the last time
173+ period). Typically, we recommend including knots at `0` and `(n_times -
174+ 1)`, but this is not required. When `knots` is an integer, then there are
175+ knots with locations equally spaced across the time periods, (including
176+ knots at zero and `(n_times - 1)`. When `knots` is` 1`, there is a single
177+ common regression coefficient used for all time periods. If `knots` is set
178+ to `None`, then the numbers of knots used is equal to the number of time
180179 periods in the case of a geo model. This is equivalent to each time period
181180 having its own regression coefficient. If `knots` is set to `None` in the
182181 case of a national model, then the number of knots used is `1`. Default:
@@ -235,7 +234,7 @@ class ModelSpec:
235234 constants .TREATMENT_PRIOR_TYPE_CONTRIBUTION
236235 )
237236 non_media_baseline_values : Sequence [float | str ] | None = None
238- knots : int | list [int ] | None = None
237+ knots : int | Collection [int ] | None = None
239238 baseline_geo : int | str | None = None
240239 holdout_id : np .ndarray | None = None
241240 control_population_scaling_id : np .ndarray | None = None
@@ -321,6 +320,12 @@ def __post_init__(self):
321320 prior_type_name = "rf_prior_type" ,
322321 )
323322
323+ if isinstance (self .knots , Collection ):
324+ knots_list = list (self .knots )
325+ if not all (isinstance (x , (int , np .integer )) for x in knots_list ):
326+ raise ValueError ("`knots` must be a sequence of integers." )
327+ object .__setattr__ (self , "knots" , [int (x ) for x in knots_list ])
328+
324329 # Validate knots.
325330 if isinstance (self .knots , list ) and not self .knots :
326331 raise ValueError ("The `knots` parameter cannot be an empty list." )
@@ -330,6 +335,10 @@ def __post_init__(self):
330335 raise ValueError (
331336 "The `knots` parameter cannot be set when `enable_aks` is True."
332337 )
338+ if not (self .knots is None or isinstance (self .knots , (int , list ))):
339+ raise ValueError (
340+ f"Unsupported type for `knots` parameter: { type (self .knots )} ."
341+ )
333342
334343 @property
335344 def effective_media_prior_type (self ) -> str :
0 commit comments