@@ -967,7 +967,7 @@ def test_incremental_outcome_tensors(self):
967967 hist_spend = backend .to_tensor (
968968 [350 , 400 , 200 , 50 , 500 ], dtype = backend .float32
969969 )
970- ( new_media , new_reach , new_frequency ) = (
970+ new_media , new_reach , new_frequency = (
971971 self .budget_optimizer_media_and_rf ._get_incremental_outcome_tensors (
972972 hist_spend , spend
973973 )
@@ -995,7 +995,7 @@ def test_incremental_outcome_tensors_with_optimal_frequency(self):
995995 [350 , 400 , 200 , 50 , 500 ], dtype = backend .float32
996996 )
997997 optimal_frequency = backend .to_tensor ([2 , 2 ], dtype = backend .float32 )
998- ( new_media , new_reach , new_frequency ) = (
998+ new_media , new_reach , new_frequency = (
999999 self .budget_optimizer_media_and_rf ._get_incremental_outcome_tensors (
10001000 hist_spend = hist_spend ,
10011001 spend = spend ,
@@ -1537,7 +1537,7 @@ def test_get_round_factor_budget_raise_error(self):
15371537 self .budget_optimizer_media_and_rf .optimize (budget = - 10_000 )
15381538
15391539 def test_get_optimization_bounds_correct (self ):
1540- ( lower_bound , upper_bound ) = optimizer .get_optimization_bounds (
1540+ lower_bound , upper_bound = optimizer .get_optimization_bounds (
15411541 n_channels = 5 ,
15421542 spend = np .array ([10642.5 , 22222.0 , 33333.0 , 44444.0 , 55555.0 ]),
15431543 round_factor = - 2 ,
@@ -2134,7 +2134,7 @@ def test_trim_grid(self):
21342134 ),
21352135 ),
21362136 )
2137- ( updated_spend , updated_incremental_outcome ) = grid .trim_grids (
2137+ updated_spend , updated_incremental_outcome = grid .trim_grids (
21382138 spend_bound_lower = np .array ([100 , 100 , 400 , 0 , 200 ]),
21392139 spend_bound_upper = np .array ([400 , 300 , 400 , 100 , 300 ]),
21402140 )
@@ -2848,7 +2848,9 @@ def test_budget_data_with_specified_pct_of_spend(
28482848 )
28492849 expected_pct_of_spend = [0.1 , 0.2 , 0.3 , 0.3 , 0.1 ]
28502850
2851- idata = self .budget_optimizer_media_and_rf ._meridian .input_data
2851+ idata = (
2852+ self .budget_optimizer_media_and_rf ._analyzer .model_context .input_data
2853+ )
28522854 paid_channels = list (idata .get_all_paid_channels ())
28532855 pct_of_spend = idata .get_paid_channels_argument_builder ()(** {
28542856 paid_channels [0 ]: 0.1 ,
@@ -2888,7 +2890,9 @@ def test_budget_data_with_new_data_with_specified_pct_of_spend(
28882890 )
28892891 expected_pct_of_spend = [0.1 , 0.2 , 0.3 , 0.3 , 0.1 ]
28902892
2891- idata = self .budget_optimizer_media_and_rf ._meridian .input_data
2893+ idata = (
2894+ self .budget_optimizer_media_and_rf ._analyzer .model_context .input_data
2895+ )
28922896 paid_channels = list (idata .get_all_paid_channels ())
28932897 pct_of_spend = idata .get_paid_channels_argument_builder ()(** {
28942898 paid_channels [0 ]: 0.1 ,
@@ -3072,22 +3076,22 @@ def test_optimize_when_no_warning_raised_for_mroi_constraint(self):
30723076 self .assertEmpty (w_list , '\n ' .join ([str (w .message ) for w in w_list ]))
30733077
30743078 def test_get_response_curves_new_times_data_correct (self ):
3075- meridian = self .budget_optimizer_media_and_rf ._meridian
3076- max_lag = meridian .model_spec .max_lag
3079+ ctx = self .budget_optimizer_media_and_rf ._analyzer . model_context
3080+ max_lag = ctx .model_spec .max_lag
30773081 n_new_times = 15
30783082 total_times = max_lag + n_new_times
3079- new_data_end_date = meridian .input_data .time .values [- 1 ]
3080- selected_times_start_date = meridian .input_data .time .values [- n_new_times ]
3081- selected_times = meridian .input_data .time .values [- n_new_times :].tolist ()
3083+ new_data_end_date = ctx .input_data .time .values [- 1 ]
3084+ selected_times_start_date = ctx .input_data .time .values [- n_new_times ]
3085+ selected_times = ctx .input_data .time .values [- n_new_times :].tolist ()
30823086
3083- new_data_times = meridian .input_data .time .values [- total_times :].tolist ()
3087+ new_data_times = ctx .input_data .time .values [- total_times :].tolist ()
30843088 new_data = analyzer .DataTensors (
3085- media = meridian .media_tensors .media [..., - total_times :, :],
3086- media_spend = meridian .media_tensors .media_spend [..., - total_times :, :],
3087- reach = meridian .rf_tensors .reach [..., - total_times :, :],
3088- frequency = meridian .rf_tensors .frequency [..., - total_times :, :],
3089- rf_spend = meridian .rf_tensors .rf_spend [..., - total_times :, :],
3090- revenue_per_kpi = meridian .revenue_per_kpi [..., - total_times :],
3089+ media = ctx .media_tensors .media [..., - total_times :, :],
3090+ media_spend = ctx .media_tensors .media_spend [..., - total_times :, :],
3091+ reach = ctx .rf_tensors .reach [..., - total_times :, :],
3092+ frequency = ctx .rf_tensors .frequency [..., - total_times :, :],
3093+ rf_spend = ctx .rf_tensors .rf_spend [..., - total_times :, :],
3094+ revenue_per_kpi = ctx .revenue_per_kpi [..., - total_times :],
30913095 time = new_data_times ,
30923096 )
30933097
@@ -3131,7 +3135,7 @@ def test_get_response_curves_new_times_data_correct(self):
31313135 self .assertEqual (kwargs ['selected_times' ], selected_times )
31323136
31333137 def test_create_budget_dataset_selected_geos_correct (self ):
3134- selected_geos = self .budget_optimizer_media_and_rf ._meridian .input_data .geo .values .tolist ()[
3138+ selected_geos = self .budget_optimizer_media_and_rf ._analyzer . model_context .input_data .geo .values .tolist ()[
31353139 :2
31363140 ]
31373141 with mock .patch .object (
@@ -3153,7 +3157,7 @@ def test_create_budget_dataset_selected_geos_correct(self):
31533157 )
31543158
31553159 def test_get_response_curves_selected_geos_correct (self ):
3156- selected_geos = self .budget_optimizer_media_and_rf ._meridian .input_data .geo .values .tolist ()[
3160+ selected_geos = self .budget_optimizer_media_and_rf ._analyzer . model_context .input_data .geo .values .tolist ()[
31573161 :2
31583162 ]
31593163 with mock .patch .object (
@@ -3783,27 +3787,40 @@ def setUp(self):
37833787 mock_data_kpi_output = mock .create_autospec (
37843788 input_data .InputData , instance = True
37853789 )
3786- meridian = mock .create_autospec (
3787- model . Meridian , instance = True , input_data = mock_data
3790+ model_context = mock .create_autospec (
3791+ context . ModelContext , instance = True , input_data = mock_data
37883792 )
3789- meridian_kpi_output = mock .create_autospec (
3790- model . Meridian , instance = True , input_data = mock_data_kpi_output
3793+ model_context_kpi_output = mock .create_autospec (
3794+ context . ModelContext , instance = True , input_data = mock_data_kpi_output
37913795 )
37923796 n_times = 149
37933797 n_geos = 10
37943798 self .revenue_per_kpi = data_test_utils .constant_revenue_per_kpi (
37953799 n_geos = n_geos , n_times = n_times , value = 1.0
37963800 )
3797- meridian .input_data .kpi_type = c .REVENUE
3798- meridian .input_data .revenue_per_kpi = self .revenue_per_kpi
3799- meridian .input_data .time_coordinates .interval_days = 7
3800- meridian_kpi_output .input_data .kpi_type = c .NON_REVENUE
3801- meridian_kpi_output .input_data .revenue_per_kpi = None
3802- meridian_kpi_output .input_data .time_coordinates .interval_days = 7
3801+ model_context .input_data .kpi_type = c .REVENUE
3802+ model_context .input_data .revenue_per_kpi = self .revenue_per_kpi
3803+ model_context .input_data .time_coordinates .interval_days = 7
3804+ model_context_kpi_output .input_data .kpi_type = c .NON_REVENUE
3805+ model_context_kpi_output .input_data .revenue_per_kpi = None
3806+ model_context_kpi_output .input_data .time_coordinates .interval_days = 7
38033807
3804- self .budget_optimizer = optimizer .BudgetOptimizer (meridian )
3808+ meridian_mock = mock .create_autospec (
3809+ model .Meridian ,
3810+ instance = True ,
3811+ input_data = mock_data ,
3812+ model_context = model_context ,
3813+ )
3814+ meridian_kpi_output_mock = mock .create_autospec (
3815+ model .Meridian ,
3816+ instance = True ,
3817+ input_data = mock_data_kpi_output ,
3818+ model_context = model_context_kpi_output ,
3819+
3820+ )
3821+ self .budget_optimizer = optimizer .BudgetOptimizer (meridian_mock )
38053822 self .budget_optimizer_kpi_output = optimizer .BudgetOptimizer (
3806- meridian_kpi_output
3823+ meridian_kpi_output_mock
38073824 )
38083825 self .optimization_grid = optimizer .OptimizationGrid (
38093826 _grid_dataset = mock .MagicMock (),
@@ -4192,7 +4209,7 @@ def test_output_scenario_plan_card_stats_text_with_euro_currency(self):
41924209 stats_section = analysis_test_utils .get_child_element (card , 'stats-section' )
41934210 stats = stats_section .findall ('stats' )
41944211 self .assertLen (stats , 6 )
4195- ( non_optimized_budget , optimized_budget , _ , _ , _ , _ ) = stats
4212+ non_optimized_budget , optimized_budget , _ , _ , _ , _ = stats
41964213
41974214 with self .subTest ('non_optimized_budget' ):
41984215 stat = analysis_test_utils .get_child_element (
@@ -4227,7 +4244,7 @@ def test_output_scenario_plan_card_stats_text_with_euro_currency(self):
42274244 )
42284245 stats_kpi = stats_section_kpi .findall ('stats' )
42294246 self .assertLen (stats_kpi , 6 )
4230- ( _ , _ , non_optimized_cpik , optimized_cpik , _ , _ ) = stats_kpi
4247+ _ , _ , non_optimized_cpik , optimized_cpik , _ , _ = stats_kpi
42314248
42324249 with self .subTest ('non_optimized_cpik' ):
42334250 stat = analysis_test_utils .get_child_element (
0 commit comments