Skip to content

Commit 355baf0

Browse files
lukmazThe Meridian Authors
authored andcommitted
[refactor] Remove BudgetOptimizer.meridian field.
PiperOrigin-RevId: 857304790
1 parent ec7b1b0 commit 355baf0

File tree

2 files changed

+75
-51
lines changed

2 files changed

+75
-51
lines changed

meridian/analysis/optimizer.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def optimize(
235235
spend_constraint_lower = spend_constraint_default
236236
if spend_constraint_upper is None:
237237
spend_constraint_upper = spend_constraint_default
238-
(optimization_lower_bound, optimization_upper_bound) = (
238+
optimization_lower_bound, optimization_upper_bound = (
239239
get_optimization_bounds(
240240
n_channels=len(self.channels),
241241
spend=spend,
@@ -253,7 +253,7 @@ def optimize(
253253
' It is only a problem when you use a much smaller budget, '
254254
' for which the intended step size is smaller. '
255255
)
256-
(spend_grid, incremental_outcome_grid) = self.trim_grids(
256+
spend_grid, incremental_outcome_grid = self.trim_grids(
257257
spend_bound_lower=optimization_lower_bound,
258258
spend_bound_upper=optimization_upper_bound,
259259
)
@@ -925,7 +925,7 @@ def get_response_curves(self) -> xr.Dataset:
925925
"""
926926
channels = self.optimized_data.channel.values
927927
selected_times = _expand_selected_times(
928-
model_context=self.meridian.model_context,
928+
model_context=self.analyzer.model_context,
929929
start_date=self.optimized_data.start_date,
930930
end_date=self.optimized_data.end_date,
931931
new_data=self.new_data,
@@ -1065,7 +1065,9 @@ def _gen_optimization_summary(self, currency: str) -> str:
10651065
self.template_env.globals[c.START_DATE] = start_date.strftime(
10661066
f'%b {start_date.day}, %Y'
10671067
)
1068-
interval_days = self.meridian.input_data.time_coordinates.interval_days
1068+
interval_days = (
1069+
self.analyzer.model_context.input_data.time_coordinates.interval_days
1070+
)
10691071
end_date = tc.normalize_date(self.optimized_data.end_date)
10701072
end_date_adjusted = end_date + pd.Timedelta(days=interval_days)
10711073
self.template_env.globals[c.END_DATE] = end_date_adjusted.strftime(
@@ -1324,9 +1326,16 @@ class BudgetOptimizer:
13241326
results can be viewed as plots and as an HTML summary output page.
13251327
"""
13261328

1327-
def __init__(self, meridian: model.Meridian):
1329+
def __init__(
1330+
self,
1331+
meridian: model.Meridian,
1332+
):
13281333
self._meridian = meridian
1329-
self._analyzer = analyzer_module.Analyzer(self._meridian)
1334+
self._analyzer = analyzer_module.Analyzer(
1335+
model_context=meridian.model_context,
1336+
model_equations=meridian.model_equations,
1337+
inference_data=meridian.inference_data,
1338+
)
13301339

13311340
def _validate_model_fit(self, use_posterior: bool):
13321341
"""Validates that the model is fit."""
@@ -1929,7 +1938,7 @@ def _validate_grid(
19291938
pct_of_spend=pct_of_spend,
19301939
)
19311940
spend = budget * valid_pct_of_spend
1932-
(optimization_lower_bound, optimization_upper_bound) = (
1941+
optimization_lower_bound, optimization_upper_bound = (
19331942
get_optimization_bounds(
19341943
n_channels=n_channels,
19351944
spend=spend,
@@ -2108,7 +2117,7 @@ def create_optimization_grid(
21082117
)
21092118
spend = budget * valid_pct_of_spend
21102119
round_factor = get_round_factor(budget, gtol)
2111-
(optimization_lower_bound, optimization_upper_bound) = (
2120+
optimization_lower_bound, optimization_upper_bound = (
21122121
get_optimization_bounds(
21132122
n_channels=n_paid_channels,
21142123
spend=spend,
@@ -2138,7 +2147,7 @@ def create_optimization_grid(
21382147
optimal_frequency = None
21392148

21402149
step_size = 10 ** (-round_factor)
2141-
(spend_grid, incremental_outcome_grid) = self._create_grids(
2150+
spend_grid, incremental_outcome_grid = self._create_grids(
21422151
spend=hist_spend,
21432152
spend_bound_lower=optimization_lower_bound,
21442153
spend_bound_upper=optimization_upper_bound,
@@ -2324,13 +2333,11 @@ def _create_budget_dataset(
23242333
)
23252334
spend_tensor = backend.to_tensor(spend, dtype=backend.float32)
23262335
hist_spend = backend.to_tensor(hist_spend, dtype=backend.float32)
2327-
(new_media, new_reach, new_frequency) = (
2328-
self._get_incremental_outcome_tensors(
2329-
hist_spend,
2330-
spend_tensor,
2331-
new_data=filled_data.filter_fields(c.PAID_CHANNELS),
2332-
optimal_frequency=optimal_frequency,
2333-
)
2336+
new_media, new_reach, new_frequency = self._get_incremental_outcome_tensors(
2337+
hist_spend,
2338+
spend_tensor,
2339+
new_data=filled_data.filter_fields(c.PAID_CHANNELS),
2340+
optimal_frequency=optimal_frequency,
23342341
)
23352342
budget = np.sum(spend_tensor)
23362343
inc_outcome_data = analyzer_module.DataTensors(
@@ -2961,7 +2968,7 @@ def _get_spend_bounds(
29612968
spend_bounds: tuple of np.ndarray of size `n_total_channels` containing
29622969
the untreated lower and upper bound spend for each media and RF channel.
29632970
"""
2964-
(spend_const_lower, spend_const_upper) = _validate_spend_constraints(
2971+
spend_const_lower, spend_const_upper = _validate_spend_constraints(
29652972
n_channels,
29662973
spend_constraint_lower,
29672974
spend_constraint_upper,

meridian/analysis/optimizer_test.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def test_incremental_outcome_tensors(self):
967967
hist_spend = backend.to_tensor(
968968
[350, 400, 200, 50, 500], dtype=backend.float32
969969
)
970-
(new_media, new_reach, new_frequency) = (
970+
new_media, new_reach, new_frequency = (
971971
self.budget_optimizer_media_and_rf._get_incremental_outcome_tensors(
972972
hist_spend, spend
973973
)
@@ -995,7 +995,7 @@ def test_incremental_outcome_tensors_with_optimal_frequency(self):
995995
[350, 400, 200, 50, 500], dtype=backend.float32
996996
)
997997
optimal_frequency = backend.to_tensor([2, 2], dtype=backend.float32)
998-
(new_media, new_reach, new_frequency) = (
998+
new_media, new_reach, new_frequency = (
999999
self.budget_optimizer_media_and_rf._get_incremental_outcome_tensors(
10001000
hist_spend=hist_spend,
10011001
spend=spend,
@@ -1537,7 +1537,7 @@ def test_get_round_factor_budget_raise_error(self):
15371537
self.budget_optimizer_media_and_rf.optimize(budget=-10_000)
15381538

15391539
def test_get_optimization_bounds_correct(self):
1540-
(lower_bound, upper_bound) = optimizer.get_optimization_bounds(
1540+
lower_bound, upper_bound = optimizer.get_optimization_bounds(
15411541
n_channels=5,
15421542
spend=np.array([10642.5, 22222.0, 33333.0, 44444.0, 55555.0]),
15431543
round_factor=-2,
@@ -2134,7 +2134,7 @@ def test_trim_grid(self):
21342134
),
21352135
),
21362136
)
2137-
(updated_spend, updated_incremental_outcome) = grid.trim_grids(
2137+
updated_spend, updated_incremental_outcome = grid.trim_grids(
21382138
spend_bound_lower=np.array([100, 100, 400, 0, 200]),
21392139
spend_bound_upper=np.array([400, 300, 400, 100, 300]),
21402140
)
@@ -2848,7 +2848,9 @@ def test_budget_data_with_specified_pct_of_spend(
28482848
)
28492849
expected_pct_of_spend = [0.1, 0.2, 0.3, 0.3, 0.1]
28502850

2851-
idata = self.budget_optimizer_media_and_rf._meridian.input_data
2851+
idata = (
2852+
self.budget_optimizer_media_and_rf._analyzer.model_context.input_data
2853+
)
28522854
paid_channels = list(idata.get_all_paid_channels())
28532855
pct_of_spend = idata.get_paid_channels_argument_builder()(**{
28542856
paid_channels[0]: 0.1,
@@ -2888,7 +2890,9 @@ def test_budget_data_with_new_data_with_specified_pct_of_spend(
28882890
)
28892891
expected_pct_of_spend = [0.1, 0.2, 0.3, 0.3, 0.1]
28902892

2891-
idata = self.budget_optimizer_media_and_rf._meridian.input_data
2893+
idata = (
2894+
self.budget_optimizer_media_and_rf._analyzer.model_context.input_data
2895+
)
28922896
paid_channels = list(idata.get_all_paid_channels())
28932897
pct_of_spend = idata.get_paid_channels_argument_builder()(**{
28942898
paid_channels[0]: 0.1,
@@ -3072,22 +3076,22 @@ def test_optimize_when_no_warning_raised_for_mroi_constraint(self):
30723076
self.assertEmpty(w_list, '\n'.join([str(w.message) for w in w_list]))
30733077

30743078
def test_get_response_curves_new_times_data_correct(self):
3075-
meridian = self.budget_optimizer_media_and_rf._meridian
3076-
max_lag = meridian.model_spec.max_lag
3079+
ctx = self.budget_optimizer_media_and_rf._analyzer.model_context
3080+
max_lag = ctx.model_spec.max_lag
30773081
n_new_times = 15
30783082
total_times = max_lag + n_new_times
3079-
new_data_end_date = meridian.input_data.time.values[-1]
3080-
selected_times_start_date = meridian.input_data.time.values[-n_new_times]
3081-
selected_times = meridian.input_data.time.values[-n_new_times:].tolist()
3083+
new_data_end_date = ctx.input_data.time.values[-1]
3084+
selected_times_start_date = ctx.input_data.time.values[-n_new_times]
3085+
selected_times = ctx.input_data.time.values[-n_new_times:].tolist()
30823086

3083-
new_data_times = meridian.input_data.time.values[-total_times:].tolist()
3087+
new_data_times = ctx.input_data.time.values[-total_times:].tolist()
30843088
new_data = analyzer.DataTensors(
3085-
media=meridian.media_tensors.media[..., -total_times:, :],
3086-
media_spend=meridian.media_tensors.media_spend[..., -total_times:, :],
3087-
reach=meridian.rf_tensors.reach[..., -total_times:, :],
3088-
frequency=meridian.rf_tensors.frequency[..., -total_times:, :],
3089-
rf_spend=meridian.rf_tensors.rf_spend[..., -total_times:, :],
3090-
revenue_per_kpi=meridian.revenue_per_kpi[..., -total_times:],
3089+
media=ctx.media_tensors.media[..., -total_times:, :],
3090+
media_spend=ctx.media_tensors.media_spend[..., -total_times:, :],
3091+
reach=ctx.rf_tensors.reach[..., -total_times:, :],
3092+
frequency=ctx.rf_tensors.frequency[..., -total_times:, :],
3093+
rf_spend=ctx.rf_tensors.rf_spend[..., -total_times:, :],
3094+
revenue_per_kpi=ctx.revenue_per_kpi[..., -total_times:],
30913095
time=new_data_times,
30923096
)
30933097

@@ -3131,7 +3135,7 @@ def test_get_response_curves_new_times_data_correct(self):
31313135
self.assertEqual(kwargs['selected_times'], selected_times)
31323136

31333137
def test_create_budget_dataset_selected_geos_correct(self):
3134-
selected_geos = self.budget_optimizer_media_and_rf._meridian.input_data.geo.values.tolist()[
3138+
selected_geos = self.budget_optimizer_media_and_rf._analyzer.model_context.input_data.geo.values.tolist()[
31353139
:2
31363140
]
31373141
with mock.patch.object(
@@ -3153,7 +3157,7 @@ def test_create_budget_dataset_selected_geos_correct(self):
31533157
)
31543158

31553159
def test_get_response_curves_selected_geos_correct(self):
3156-
selected_geos = self.budget_optimizer_media_and_rf._meridian.input_data.geo.values.tolist()[
3160+
selected_geos = self.budget_optimizer_media_and_rf._analyzer.model_context.input_data.geo.values.tolist()[
31573161
:2
31583162
]
31593163
with mock.patch.object(
@@ -3783,27 +3787,40 @@ def setUp(self):
37833787
mock_data_kpi_output = mock.create_autospec(
37843788
input_data.InputData, instance=True
37853789
)
3786-
meridian = mock.create_autospec(
3787-
model.Meridian, instance=True, input_data=mock_data
3790+
model_context = mock.create_autospec(
3791+
context.ModelContext, instance=True, input_data=mock_data
37883792
)
3789-
meridian_kpi_output = mock.create_autospec(
3790-
model.Meridian, instance=True, input_data=mock_data_kpi_output
3793+
model_context_kpi_output = mock.create_autospec(
3794+
context.ModelContext, instance=True, input_data=mock_data_kpi_output
37913795
)
37923796
n_times = 149
37933797
n_geos = 10
37943798
self.revenue_per_kpi = data_test_utils.constant_revenue_per_kpi(
37953799
n_geos=n_geos, n_times=n_times, value=1.0
37963800
)
3797-
meridian.input_data.kpi_type = c.REVENUE
3798-
meridian.input_data.revenue_per_kpi = self.revenue_per_kpi
3799-
meridian.input_data.time_coordinates.interval_days = 7
3800-
meridian_kpi_output.input_data.kpi_type = c.NON_REVENUE
3801-
meridian_kpi_output.input_data.revenue_per_kpi = None
3802-
meridian_kpi_output.input_data.time_coordinates.interval_days = 7
3801+
model_context.input_data.kpi_type = c.REVENUE
3802+
model_context.input_data.revenue_per_kpi = self.revenue_per_kpi
3803+
model_context.input_data.time_coordinates.interval_days = 7
3804+
model_context_kpi_output.input_data.kpi_type = c.NON_REVENUE
3805+
model_context_kpi_output.input_data.revenue_per_kpi = None
3806+
model_context_kpi_output.input_data.time_coordinates.interval_days = 7
38033807

3804-
self.budget_optimizer = optimizer.BudgetOptimizer(meridian)
3808+
meridian_mock = mock.create_autospec(
3809+
model.Meridian,
3810+
instance=True,
3811+
input_data=mock_data,
3812+
model_context=model_context,
3813+
)
3814+
meridian_kpi_output_mock = mock.create_autospec(
3815+
model.Meridian,
3816+
instance=True,
3817+
input_data=mock_data_kpi_output,
3818+
model_context=model_context_kpi_output,
3819+
3820+
)
3821+
self.budget_optimizer = optimizer.BudgetOptimizer(meridian_mock)
38053822
self.budget_optimizer_kpi_output = optimizer.BudgetOptimizer(
3806-
meridian_kpi_output
3823+
meridian_kpi_output_mock
38073824
)
38083825
self.optimization_grid = optimizer.OptimizationGrid(
38093826
_grid_dataset=mock.MagicMock(),
@@ -4192,7 +4209,7 @@ def test_output_scenario_plan_card_stats_text_with_euro_currency(self):
41924209
stats_section = analysis_test_utils.get_child_element(card, 'stats-section')
41934210
stats = stats_section.findall('stats')
41944211
self.assertLen(stats, 6)
4195-
(non_optimized_budget, optimized_budget, _, _, _, _) = stats
4212+
non_optimized_budget, optimized_budget, _, _, _, _ = stats
41964213

41974214
with self.subTest('non_optimized_budget'):
41984215
stat = analysis_test_utils.get_child_element(
@@ -4227,7 +4244,7 @@ def test_output_scenario_plan_card_stats_text_with_euro_currency(self):
42274244
)
42284245
stats_kpi = stats_section_kpi.findall('stats')
42294246
self.assertLen(stats_kpi, 6)
4230-
(_, _, non_optimized_cpik, optimized_cpik, _, _) = stats_kpi
4247+
_, _, non_optimized_cpik, optimized_cpik, _, _ = stats_kpi
42314248

42324249
with self.subTest('non_optimized_cpik'):
42334250
stat = analysis_test_utils.get_child_element(

0 commit comments

Comments
 (0)