From e0763cad48c1e3f1648232a7708ef03f2b6c22f8 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Fri, 7 Nov 2025 14:12:16 -0800 Subject: [PATCH] Optimize the test setup PiperOrigin-RevId: 829573098 --- meridian/analysis/optimizer_test.py | 70 +++++++++++++++-------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/meridian/analysis/optimizer_test.py b/meridian/analysis/optimizer_test.py index b3ef916ae..7a02113ee 100644 --- a/meridian/analysis/optimizer_test.py +++ b/meridian/analysis/optimizer_test.py @@ -281,6 +281,43 @@ def _get_sample_optimized_data(is_revenue_kpi: bool = True) -> xr.Dataset: class OptimizerAlgorithmTest(parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.inference_data_media_and_rf = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_media_and_rf.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_and_rf.nc') + ), + ) + cls.inference_data_media_only = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_media_only.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_only.nc') + ), + ) + cls.inference_data_rf_only = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_rf_only.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_rf_only.nc') + ), + ) + cls.inference_data_all_channels = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_non_paid.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_non_paid.nc') + ), + ) + # TODO: Update the sample datasets to span over 1 year. def setUp(self): super(OptimizerAlgorithmTest, self).setUp() @@ -331,39 +368,6 @@ def setUp(self): ) ) - self.inference_data_media_and_rf = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_media_and_rf.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_and_rf.nc') - ), - ) - self.inference_data_media_only = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_media_only.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_only.nc') - ), - ) - self.inference_data_rf_only = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_rf_only.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_rf_only.nc') - ), - ) - self.inference_data_all_channels = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_non_paid.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_non_paid.nc') - ), - ) - self.meridian_media_and_rf = model.Meridian( input_data=self.input_data_media_and_rf )