diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ac21920d..396f817e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +* Add model fitting guardrail using EDA to `Meridian`. * Introduce serde package: a serialization and deserialization library for Meridian model with a protocol buffer schema. * Add model quality checks in the `analysis.review` module. diff --git a/meridian/analysis/optimizer_test.py b/meridian/analysis/optimizer_test.py index b3ef916ae..7a02113ee 100644 --- a/meridian/analysis/optimizer_test.py +++ b/meridian/analysis/optimizer_test.py @@ -281,6 +281,43 @@ def _get_sample_optimized_data(is_revenue_kpi: bool = True) -> xr.Dataset: class OptimizerAlgorithmTest(parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.inference_data_media_and_rf = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_media_and_rf.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_and_rf.nc') + ), + ) + cls.inference_data_media_only = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_media_only.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_only.nc') + ), + ) + cls.inference_data_rf_only = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_rf_only.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_rf_only.nc') + ), + ) + cls.inference_data_all_channels = az.InferenceData( + prior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_prior_non_paid.nc') + ), + posterior=xr.open_dataset( + os.path.join(_TEST_DATA_DIR, 'sample_posterior_non_paid.nc') + ), + ) + # TODO: Update the sample datasets to span over 1 year. def setUp(self): super(OptimizerAlgorithmTest, self).setUp() @@ -331,39 +368,6 @@ def setUp(self): ) ) - self.inference_data_media_and_rf = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_media_and_rf.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_and_rf.nc') - ), - ) - self.inference_data_media_only = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_media_only.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_media_only.nc') - ), - ) - self.inference_data_rf_only = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_rf_only.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_rf_only.nc') - ), - ) - self.inference_data_all_channels = az.InferenceData( - prior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_prior_non_paid.nc') - ), - posterior=xr.open_dataset( - os.path.join(_TEST_DATA_DIR, 'sample_posterior_non_paid.nc') - ), - ) - self.meridian_media_and_rf = model.Meridian( input_data=self.input_data_media_and_rf ) diff --git a/meridian/model/eda/eda_engine.py b/meridian/model/eda/eda_engine.py index fed5f33f8..736fe188c 100644 --- a/meridian/model/eda/eda_engine.py +++ b/meridian/model/eda/eda_engine.py @@ -19,7 +19,7 @@ import dataclasses import functools import typing -from typing import Optional, Sequence +from typing import Optional, Protocol, Sequence from meridian import backend from meridian import constants @@ -62,6 +62,15 @@ _VIF_COL_NAME = 'VIF' +class _NamedEDACheckCallable(Protocol): + """A callable that returns an EDAOutcome and has a __name__ attribute.""" + + __name__: str + + def __call__(self) -> eda_outcome.EDAOutcome: + ... + + class GeoLevelCheckOnNationalModelError(Exception): """Raised when a geo-level check is called on a national model.""" @@ -339,6 +348,10 @@ def __init__( def spec(self) -> eda_spec.EDASpec: return self._spec + @property + def _is_national_data(self) -> bool: + return self._meridian.is_national + @functools.cached_property def controls_scaled_da(self) -> xr.DataArray | None: if self._meridian.input_data.controls is None: @@ -355,7 +368,7 @@ def national_controls_scaled_da(self) -> xr.DataArray | None: """Returns the national scaled controls data array.""" if self._meridian.input_data.controls is None: return None - if self._meridian.is_national: + if self._is_national_data: if self.controls_scaled_da is None: # This case should be impossible given the check above. raise RuntimeError( @@ -412,7 +425,7 @@ def national_media_spend_da(self) -> xr.DataArray | None: """Returns the national media spend data array.""" if self.media_spend_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.media_spend_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_MEDIA_SPEND else: @@ -428,7 +441,7 @@ def national_media_raw_da(self) -> xr.DataArray | None: """Returns the national raw media data array.""" if self.media_raw_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.media_raw_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_MEDIA else: @@ -445,7 +458,7 @@ def national_media_scaled_da(self) -> xr.DataArray | None: """Returns the national scaled media data array.""" if self.media_scaled_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.media_scaled_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_MEDIA_SCALED else: @@ -483,7 +496,7 @@ def national_organic_media_raw_da(self) -> xr.DataArray | None: """Returns the national raw organic media data array.""" if self.organic_media_raw_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.organic_media_raw_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_ORGANIC_MEDIA else: @@ -498,7 +511,7 @@ def national_organic_media_scaled_da(self) -> xr.DataArray | None: """Returns the national scaled organic media data array.""" if self.organic_media_scaled_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.organic_media_scaled_da.squeeze( constants.GEO, drop=True ) @@ -528,7 +541,7 @@ def national_non_media_scaled_da(self) -> xr.DataArray | None: """Returns the national scaled non-media treatment data array.""" if self._meridian.input_data.non_media_treatments is None: return None - if self._meridian.is_national: + if self._is_national_data: if self.non_media_scaled_da is None: # This case should be impossible given the check above. raise RuntimeError( @@ -565,7 +578,7 @@ def national_rf_spend_da(self) -> xr.DataArray | None: """Returns the national RF spend data array.""" if self.rf_spend_da is None: return None - if self._meridian.is_national: + if self._is_national_data: national_da = self.rf_spend_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_RF_SPEND else: @@ -728,7 +741,7 @@ def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None: @functools.cached_property def geo_population_da(self) -> xr.DataArray | None: - if self._meridian.is_national: + if self._is_national_data: return None return xr.DataArray( self._meridian.population, @@ -760,7 +773,7 @@ def _overall_scaled_kpi_invariability_artifact( @functools.cached_property def national_kpi_scaled_da(self) -> xr.DataArray: """Returns the national scaled KPI data array.""" - if self._meridian.is_national: + if self._is_national_data: national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True) national_da.name = constants.NATIONAL_KPI_SCALED else: @@ -934,6 +947,24 @@ def national_all_freq_da(self) -> xr.DataArray | None: da.name = constants.NATIONAL_ALL_FREQUENCY return da + @property + def _critical_checks( + self, + ) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]: + """Returns a list of critical checks to be performed.""" + checks = [ + ( + self.check_overall_kpi_invariability, + eda_outcome.EDACheckType.KPI_INVARIABILITY, + ), + (self.check_vif, eda_outcome.EDACheckType.MULTICOLLINEARITY), + ( + self.check_pairwise_corr, + eda_outcome.EDACheckType.PAIRWISE_CORRELATION, + ), + ] + return checks + def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray: """Truncates the first `start` elements of the media time of a variable.""" # This should not happen. If it does, it means this function is mis-used. @@ -1097,7 +1128,7 @@ def _get_rf_data( impressions_raw_da.values, dtype=backend.float32 ) - if self._meridian.is_national: + if self._is_national_data: national_reach_raw_da = reach_raw_da.squeeze(constants.GEO, drop=True) national_reach_raw_da.name = names.national_reach national_reach_scaled_da = reach_scaled_da.squeeze( @@ -1199,7 +1230,7 @@ def check_geo_pairwise_corr( GeoLevelCheckOnNationalModelError: If the model is national. """ # If the model is national, raise an error. - if self._meridian.is_national: + if self._is_national_data: raise GeoLevelCheckOnNationalModelError( 'check_geo_pairwise_corr is not supported for national models.' ) @@ -1281,7 +1312,7 @@ def check_geo_pairwise_corr( ] return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.PAIRWISE_CORR, + check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION, findings=findings, analysis_artifacts=pairwise_corr_artifacts, ) @@ -1338,11 +1369,24 @@ def check_national_pairwise_corr( ) ] return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.PAIRWISE_CORR, + check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION, findings=findings, analysis_artifacts=pairwise_corr_artifacts, ) + def check_pairwise_corr( + self, + ) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]: + """Checks pairwise correlation among treatments and controls. + + Returns: + An EDAOutcome object with findings and result values. + """ + if self._is_national_data: + return self.check_national_pairwise_corr() + else: + return self.check_geo_pairwise_corr() + def _check_std( self, data: xr.DataArray, @@ -1374,7 +1418,7 @@ def check_geo_std( self, ) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]: """Checks std for geo-level KPI, treatments, R&F, and controls.""" - if self._meridian.is_national: + if self._is_national_data: raise ValueError('check_geo_std is not applicable for national models.') findings = [] @@ -1447,7 +1491,7 @@ def check_geo_std( ) return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.STD, + check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION, findings=findings, analysis_artifacts=artifacts, ) @@ -1526,14 +1570,27 @@ def check_national_std( ) return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.STD, + check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION, findings=findings, analysis_artifacts=artifacts, ) + def check_std( + self, + ) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]: + """Checks standard deviation for treatments and controls. + + Returns: + An EDAOutcome object with findings and result values. + """ + if self._is_national_data: + return self.check_national_std() + else: + return self.check_geo_std() + def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]: """Computes geo-level variance inflation factor among treatments and controls.""" - if self._meridian.is_national: + if self._is_national_data: raise ValueError( 'Geo-level VIF checks are not applicable for national models.' ) @@ -1614,7 +1671,7 @@ def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]: ) return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.VIF, + check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY, findings=findings, analysis_artifacts=[overall_vif_artifact, geo_vif_artifact], ) @@ -1667,11 +1724,22 @@ def check_national_vif( ) ) return eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.VIF, + check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY, findings=findings, analysis_artifacts=[national_vif_artifact], ) + def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]: + """Computes variance inflation factor among treatments and controls. + + Returns: + An EDAOutcome object with findings and result values. + """ + if self._is_national_data: + return self.check_national_vif() + else: + return self.check_geo_vif() + @property def kpi_has_variability(self) -> bool: """Returns True if the KPI has variability across geos and times.""" @@ -1683,7 +1751,7 @@ def kpi_has_variability(self) -> bool: def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome: """Checks if the KPI is constant across all geos and times.""" kpi = self._overall_scaled_kpi_invariability_artifact.kpi_da.name - geo_text = '' if self._meridian.is_national else 'geos and ' + geo_text = '' if self._is_national_data else 'geos and ' if not self.kpi_has_variability: eda_finding = eda_outcome.EDAFinding( @@ -1706,3 +1774,29 @@ def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome: findings=[eda_finding], analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact], ) + + def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]: + """Runs all critical EDA checks. + + Critical checks are those that can result in EDASeverity.ERROR findings. + + Returns: + A list of EDA outcomes, one for each check. + """ + outcomes = [] + for check, check_type in self._critical_checks: + try: + outcomes.append(check()) + except Exception as e: # pylint: disable=broad-except + error_finding = eda_outcome.EDAFinding( + severity=eda_outcome.EDASeverity.ERROR, + explanation=f'An error occurred during check {check.__name__}: {e}', + ) + outcomes.append( + eda_outcome.EDAOutcome( + check_type=check_type, + findings=[error_finding], + analysis_artifacts=[], + ) + ) + return outcomes diff --git a/meridian/model/eda/eda_engine_test.py b/meridian/model/eda/eda_engine_test.py index 816dbc72d..3a2497e02 100644 --- a/meridian/model/eda/eda_engine_test.py +++ b/meridian/model/eda/eda_engine_test.py @@ -206,6 +206,35 @@ def setUp(self): * self.mock_scale_factor ) + def _mock_critical_checks( + self, mock_results: dict[str, eda_outcome.EDAOutcome | Exception] + ): + """Mocks critical EDA checks with specified return values or exceptions.""" + for check_name, result in mock_results.items(): + patcher = mock.patch.object( + eda_engine.EDAEngine, check_name, autospec=True + ) + mock_check = self.enter_context(patcher) + if isinstance(result, Exception): + mock_check.side_effect = result + else: + mock_check.return_value = result + + def _create_eda_outcome( + self, + check_type: eda_outcome.EDACheckType, + severity: eda_outcome.EDASeverity, + ) -> eda_outcome.EDAOutcome: + """Creates an EDAOutcome with a single finding.""" + explanation = f"{check_type.name}: {severity.name}" + return eda_outcome.EDAOutcome( + check_type=check_type, + findings=[ + eda_outcome.EDAFinding(severity=severity, explanation=explanation) + ], + analysis_artifacts=[], + ) + def _mock_eda_engine_property(self, property_name, return_value): self.enter_context( mock.patch.object( @@ -3346,7 +3375,9 @@ def test_check_geo_pairwise_corr_one_error(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 2) @@ -3391,7 +3422,9 @@ def test_check_geo_pairwise_corr_one_attention(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 2) @@ -3426,7 +3459,9 @@ def test_check_geo_pairwise_corr_info_only(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 2) @@ -3479,7 +3514,9 @@ def test_check_geo_pairwise_corr_high_overall_corr(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 2) @@ -3525,7 +3562,9 @@ def test_check_geo_pairwise_corr_high_corr_in_one_geo(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 2) @@ -3549,7 +3588,9 @@ def test_check_geo_pairwise_corr_corr_matrix_has_correct_coordinates(self): engine = eda_engine.EDAEngine(meridian) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.analysis_artifacts, 2) for artifact in outcome.analysis_artifacts: @@ -3587,7 +3628,9 @@ def test_check_geo_pairwise_corr_correlation_values(self): self._mock_eda_engine_property("treatment_control_scaled_ds", mock_ds) outcome = engine.check_geo_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) expected_overall_corr = np.corrcoef( media_data.flatten(), control_data.flatten() )[0, 1] @@ -3650,7 +3693,9 @@ def test_check_national_pairwise_corr_one_error(self): ) outcome = engine.check_national_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 1) @@ -3690,7 +3735,9 @@ def test_check_national_pairwise_corr_info_only(self): ) outcome = engine.check_national_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.findings, 1) self.assertLen(outcome.analysis_artifacts, 1) @@ -3714,7 +3761,9 @@ def test_check_national_pairwise_corr_corr_matrix_has_correct_coordinates( engine = eda_engine.EDAEngine(meridian) outcome = engine.check_national_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) self.assertLen(outcome.analysis_artifacts, 1) artifact = outcome.analysis_artifacts[0] self.assertEqual(artifact.level, eda_outcome.AnalysisLevel.NATIONAL) @@ -3746,7 +3795,9 @@ def test_check_national_pairwise_corr_correlation_values(self): ) outcome = engine.check_national_pairwise_corr() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORR) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) expected_corr = np.corrcoef(media_data.flatten(), control_data.flatten())[ 0, 1 ] @@ -3775,7 +3826,9 @@ def test_check_geo_std_std_artifacts_have_correct_coordinates(self): engine = eda_engine.EDAEngine(meridian) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 4) for artifact in outcome.analysis_artifacts: @@ -3812,7 +3865,9 @@ def test_check_geo_std_calculates_std_value_correctly(self): self._mock_eda_engine_property("kpi_scaled_da", mock_kpi_da) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 2) kpi_artifact = next( artifact @@ -3856,7 +3911,9 @@ def test_check_geo_std_correctly_identifies_outliers(self, outlier_value): self._mock_eda_engine_property("kpi_scaled_da", mock_kpi_da) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 2) kpi_artifact = next( artifact @@ -3895,7 +3952,9 @@ def test_check_geo_std_returns_info_finding_when_no_issues(self): self._mock_eda_engine_property("all_reach_scaled_da", None) self._mock_eda_engine_property("all_freq_da", None) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.findings, 1) self.assertEqual(outcome.findings[0].severity, eda_outcome.EDASeverity.INFO) self.assertIn( @@ -4010,7 +4069,9 @@ def test_check_geo_std_attention_cases( self._mock_eda_engine_property("all_freq_da", None) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.findings, 1) self.assertEqual( outcome.findings[0].severity, eda_outcome.EDASeverity.ATTENTION @@ -4037,7 +4098,9 @@ def test_check_geo_std_handles_missing_rf_data(self): self._mock_eda_engine_property("all_reach_scaled_da", None) self._mock_eda_engine_property("all_freq_da", None) outcome = engine.check_geo_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 2) variables = [artifact.variable for artifact in outcome.analysis_artifacts] self.assertCountEqual( @@ -4050,7 +4113,9 @@ def test_check_national_std_std_artifacts_have_correct_coordinates(self): engine = eda_engine.EDAEngine(meridian) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 4) for artifact in outcome.analysis_artifacts: @@ -4087,7 +4152,9 @@ def test_check_national_std_calculates_std_value_correctly(self): self._mock_eda_engine_property("national_kpi_scaled_da", mock_kpi_da) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 4) kpi_artifact = next( artifact @@ -4131,7 +4198,9 @@ def test_check_national_std_correctly_identifies_outliers( self._mock_eda_engine_property("national_kpi_scaled_da", mock_kpi_da) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 4) kpi_artifact = next( artifact @@ -4173,7 +4242,9 @@ def test_check_national_std_returns_info_finding_when_no_issues(self): self._mock_eda_engine_property("national_all_reach_scaled_da", None) self._mock_eda_engine_property("national_all_freq_da", None) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.findings, 1) self.assertEqual(outcome.findings[0].severity, eda_outcome.EDASeverity.INFO) self.assertIn( @@ -4204,7 +4275,9 @@ def test_check_national_std_finds_zero_std_kpi(self): self._mock_eda_engine_property("national_all_freq_da", None) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.findings, 1) self.assertEqual( outcome.findings[0].severity, eda_outcome.EDASeverity.ATTENTION @@ -4318,7 +4391,9 @@ def test_check_national_std_attention_cases( self._mock_eda_engine_property("national_all_freq_da", None) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.findings, 1) self.assertEqual( outcome.findings[0].severity, eda_outcome.EDASeverity.ATTENTION @@ -4347,7 +4422,9 @@ def test_check_national_std_handles_missing_rf_data(self): self._mock_eda_engine_property("national_all_reach_scaled_da", None) self._mock_eda_engine_property("national_all_freq_da", None) outcome = engine.check_national_std() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.STD) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.STANDARD_DEVIATION + ) self.assertLen(outcome.analysis_artifacts, 2) variables = [artifact.variable for artifact in outcome.analysis_artifacts] self.assertCountEqual( @@ -4408,7 +4485,9 @@ def test_check_geo_vif_returns_correct_finding_severity( outcome = engine.check_geo_vif() - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.findings, 1) self.assertEqual(outcome.findings[0].severity, expected_severity) self.assertIn(expected_explanation, outcome.findings[0].explanation) @@ -4425,7 +4504,9 @@ def test_check_geo_vif_overall_artifact_is_correct(self): outcome = engine.check_geo_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 2) overall_artifact = next( @@ -4456,7 +4537,9 @@ def test_check_geo_vif_geo_artifact_is_correct(self): outcome = engine.check_geo_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 2) geo_artifact = next( @@ -4493,7 +4576,9 @@ def test_check_geo_vif_has_correct_vif_value_when_vif_is_inf(self): outcome = engine.check_geo_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 2) overall_artifact = next( @@ -4525,7 +4610,9 @@ def test_check_geo_vif_has_correct_vif_value(self): outcome = engine.check_geo_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 2) overall_artifact = next( @@ -4635,7 +4722,9 @@ def test_check_national_vif_artifact_is_correct( outcome = engine.check_national_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 1) national_artifact = outcome.analysis_artifacts[0] @@ -4664,7 +4753,9 @@ def test_check_national_vif_has_correct_vif_value_when_vif_is_inf(self): outcome = engine.check_national_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 1) national_artifact = outcome.analysis_artifacts[0] @@ -4685,7 +4776,9 @@ def test_check_national_vif_has_correct_vif_value(self): outcome = engine.check_national_vif() self.assertIsInstance(outcome, eda_outcome.EDAOutcome) - self.assertEqual(outcome.check_type, eda_outcome.EDACheckType.VIF) + self.assertEqual( + outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) self.assertLen(outcome.analysis_artifacts, 1) national_artifact = outcome.analysis_artifacts[0] @@ -4704,6 +4797,98 @@ def test_check_national_vif_has_correct_vif_value(self): national_artifact.vif_da.values, expected_national_vif ) + @parameterized.named_parameters( + dict( + testcase_name="national_model", + is_national=True, + expected_call="check_national_std", + ), + dict( + testcase_name="geo_model", + is_national=False, + expected_call="check_geo_std", + ), + ) + def test_check_std_calls_correct_level(self, is_national, expected_call): + meridian = mock.Mock(spec=model.Meridian) + meridian.is_national = is_national + engine = eda_engine.EDAEngine(meridian) + + mock_outcome = self._create_eda_outcome( + eda_outcome.EDACheckType.STANDARD_DEVIATION, + eda_outcome.EDASeverity.INFO, + ) + mock_check = self.enter_context( + mock.patch.object( + engine, expected_call, autospec=True, return_value=mock_outcome + ) + ) + result = engine.check_std() + mock_check.assert_called_once() + self.assertEqual(result, mock_outcome) + + @parameterized.named_parameters( + dict( + testcase_name="national_model", + is_national=True, + expected_call="check_national_vif", + ), + dict( + testcase_name="geo_model", + is_national=False, + expected_call="check_geo_vif", + ), + ) + def test_check_vif_calls_correct_level(self, is_national, expected_call): + meridian = mock.Mock(spec=model.Meridian) + meridian.is_national = is_national + engine = eda_engine.EDAEngine(meridian) + + mock_outcome = self._create_eda_outcome( + eda_outcome.EDACheckType.MULTICOLLINEARITY, + eda_outcome.EDASeverity.INFO, + ) + mock_check = self.enter_context( + mock.patch.object( + engine, expected_call, autospec=True, return_value=mock_outcome + ) + ) + result = engine.check_vif() + mock_check.assert_called_once() + self.assertEqual(result, mock_outcome) + + @parameterized.named_parameters( + dict( + testcase_name="national_model", + is_national=True, + expected_call="check_national_pairwise_corr", + ), + dict( + testcase_name="geo_model", + is_national=False, + expected_call="check_geo_pairwise_corr", + ), + ) + def test_check_pairwise_corr_calls_correct_level( + self, is_national, expected_call + ): + meridian = mock.Mock(spec=model.Meridian) + meridian.is_national = is_national + engine = eda_engine.EDAEngine(meridian) + + mock_outcome = self._create_eda_outcome( + eda_outcome.EDACheckType.PAIRWISE_CORRELATION, + eda_outcome.EDASeverity.INFO, + ) + mock_check = self.enter_context( + mock.patch.object( + engine, expected_call, autospec=True, return_value=mock_outcome + ) + ) + result = engine.check_pairwise_corr() + mock_check.assert_called_once() + self.assertEqual(result, mock_outcome) + @parameterized.named_parameters( dict( testcase_name="has_variability", @@ -4834,6 +5019,120 @@ def test_check_overall_kpi_invariability_has_variability( kpi_data, ) + def test_run_all_critical_checks_all_pass(self): + meridian = model.Meridian(self.input_data_with_media_and_rf) + engine = eda_engine.EDAEngine(meridian) + + mock_results = { + "check_overall_kpi_invariability": self._create_eda_outcome( + eda_outcome.EDACheckType.KPI_INVARIABILITY, + eda_outcome.EDASeverity.INFO, + ), + "check_vif": self._create_eda_outcome( + eda_outcome.EDACheckType.MULTICOLLINEARITY, + eda_outcome.EDASeverity.INFO, + ), + "check_pairwise_corr": self._create_eda_outcome( + eda_outcome.EDACheckType.PAIRWISE_CORRELATION, + eda_outcome.EDASeverity.INFO, + ), + } + self._mock_critical_checks(mock_results) + + outcomes = engine.run_all_critical_checks() + + self.assertLen(outcomes, 3) + for outcome in outcomes: + self.assertLen(outcome.findings, 1) + self.assertEqual( + outcome.findings[0].severity, eda_outcome.EDASeverity.INFO + ) + + def test_run_all_critical_checks_with_non_info_findings(self): + meridian = model.Meridian(self.input_data_with_media_and_rf) + engine = eda_engine.EDAEngine(meridian) + + mock_results = { + "check_overall_kpi_invariability": self._create_eda_outcome( + eda_outcome.EDACheckType.KPI_INVARIABILITY, + eda_outcome.EDASeverity.ERROR, + ), + "check_vif": self._create_eda_outcome( + eda_outcome.EDACheckType.MULTICOLLINEARITY, + eda_outcome.EDASeverity.ATTENTION, + ), + "check_pairwise_corr": self._create_eda_outcome( + eda_outcome.EDACheckType.PAIRWISE_CORRELATION, + eda_outcome.EDASeverity.INFO, + ), + } + self._mock_critical_checks(mock_results) + + outcomes = engine.run_all_critical_checks() + + self.assertLen(outcomes, 3) + expected_severities = [ + eda_outcome.EDASeverity.ERROR, + eda_outcome.EDASeverity.ATTENTION, + eda_outcome.EDASeverity.INFO, + ] + for i, outcome in enumerate(outcomes): + self.assertLen(outcome.findings, 1) + self.assertEqual(outcome.findings[0].severity, expected_severities[i]) + + def test_run_all_critical_checks_with_exception(self): + meridian = model.Meridian(self.input_data_with_media_and_rf) + engine = eda_engine.EDAEngine(meridian) + + mock_results = { + "check_overall_kpi_invariability": self._create_eda_outcome( + eda_outcome.EDACheckType.KPI_INVARIABILITY, + eda_outcome.EDASeverity.INFO, + ), + "check_vif": ValueError("Test Error"), + "check_pairwise_corr": TypeError("Another Error"), + } + self._mock_critical_checks(mock_results) + + outcomes = engine.run_all_critical_checks() + + self.assertLen(outcomes, 3) + + # Check check_overall_kpi_invariability + self.assertEqual( + outcomes[0].check_type, eda_outcome.EDACheckType.KPI_INVARIABILITY + ) + self.assertLen(outcomes[0].findings, 1) + self.assertEqual( + outcomes[0].findings[0].severity, eda_outcome.EDASeverity.INFO + ) + + # Check check_vif (should catch ValueError) + self.assertEqual( + outcomes[1].check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY + ) + self.assertLen(outcomes[1].findings, 1) + self.assertEqual( + outcomes[1].findings[0].severity, eda_outcome.EDASeverity.ERROR + ) + self.assertIn( + "An error occurred during check check_vif: Test Error", + outcomes[1].findings[0].explanation, + ) + + # Check check_pairwise_corr (should catch TypeError) + self.assertEqual( + outcomes[2].check_type, eda_outcome.EDACheckType.PAIRWISE_CORRELATION + ) + self.assertLen(outcomes[2].findings, 1) + self.assertEqual( + outcomes[2].findings[0].severity, eda_outcome.EDASeverity.ERROR + ) + self.assertIn( + "An error occurred during check check_pairwise_corr: Another Error", + outcomes[2].findings[0].explanation, + ) + if __name__ == "__main__": absltest.main() diff --git a/meridian/model/eda/eda_outcome.py b/meridian/model/eda/eda_outcome.py index daef8799b..f94384ec9 100644 --- a/meridian/model/eda/eda_outcome.py +++ b/meridian/model/eda/eda_outcome.py @@ -157,9 +157,9 @@ class KpiInvariabilityArtifact(AnalysisArtifact): class EDACheckType(enum.Enum): """Enumeration for the type of an EDA check.""" - PAIRWISE_CORR = enum.auto() - STD = enum.auto() - VIF = enum.auto() + PAIRWISE_CORRELATION = enum.auto() + STANDARD_DEVIATION = enum.auto() + MULTICOLLINEARITY = enum.auto() KPI_INVARIABILITY = enum.auto() diff --git a/meridian/model/eda/eda_outcome_test.py b/meridian/model/eda/eda_outcome_test.py index 13864dda8..623987e2f 100644 --- a/meridian/model/eda/eda_outcome_test.py +++ b/meridian/model/eda/eda_outcome_test.py @@ -73,12 +73,12 @@ class EdaOutcomeTest(parameterized.TestCase): extreme_corr_threshold=0.5, ) _GEO_OUTCOME = eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.PAIRWISE_CORR, + check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION, findings=[], analysis_artifacts=[_OVERALL_ARTIFACT, _GEO_ARTIFACT], ) _NATIONAL_OUTCOME = eda_outcome.EDAOutcome( - check_type=eda_outcome.EDACheckType.PAIRWISE_CORR, + check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION, findings=[], analysis_artifacts=[_NATIONAL_ARTIFACT], ) diff --git a/meridian/model/model.py b/meridian/model/model.py index 69992dec1..d67f90f47 100644 --- a/meridian/model/model.py +++ b/meridian/model/model.py @@ -14,6 +14,7 @@ """Meridian module for the geo-level Bayesian hierarchical media mix model.""" +import collections from collections.abc import Mapping, Sequence import functools import numbers @@ -35,6 +36,7 @@ from meridian.model import spec from meridian.model import transformers from meridian.model.eda import eda_engine +from meridian.model.eda import eda_outcome from meridian.model.eda import eda_spec as eda_spec_module import numpy as np @@ -42,12 +44,17 @@ "MCMCSamplingError", "MCMCOOMError", "Meridian", + "ModelFittingError", "NotFittedModelError", "save_mmm", "load_mmm", ] +class ModelFittingError(Exception): + """Model has critical issues preventing fitting.""" + + class NotFittedModelError(Exception): """Model has not been fitted.""" @@ -94,6 +101,8 @@ class Meridian: resulting data from fitting the model. eda_engine: An `EDAEngine` object containing the EDA engine. eda_spec: An `EDASpec` object containing the EDA specification. + eda_outcomes: A list of `EDAOutcome` objects containing the outcomes from + running critical EDA checks. n_geos: Number of geos in the data. n_media_channels: Number of media channels in the data. n_rf_channels: Number of reach and frequency (RF) channels in the data. @@ -164,6 +173,7 @@ def __init__( self._inference_data = ( inference_data if inference_data else az.InferenceData() ) + self._eda_spec = eda_spec self._validate_data_dependent_model_spec() @@ -203,6 +213,10 @@ def eda_engine(self) -> eda_engine.EDAEngine: def eda_spec(self) -> eda_spec_module.EDASpec: return self._eda_spec + @functools.cached_property + def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]: + return self.eda_engine.run_all_critical_checks() + @functools.cached_property def media_tensors(self) -> media.MediaTensors: return media.build_media_tensors(self.input_data, self.model_spec) @@ -1586,6 +1600,36 @@ def sample_prior(self, n_draws: int, seed: int | None = None): """ self.prior_sampler_callable(n_draws=n_draws, seed=seed) + def _run_model_fitting_guardrail(self): + """Raises an error if the model has critical EDA issues.""" + error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = ( + collections.defaultdict(list) + ) + for outcome in self.eda_outcomes: + error_findings = [ + finding + for finding in outcome.findings + if finding.severity == eda_outcome.EDASeverity.ERROR + ] + if error_findings: + error_findings_by_type[outcome.check_type].extend( + [finding.explanation for finding in error_findings] + ) + + if error_findings_by_type: + error_message_lines = [ + "Model has critical EDA issues. Please fix before running" + " `sample_posterior`.\n" + ] + for check_type, explanations in error_findings_by_type.items(): + error_message_lines.append(f"Check type: {check_type.name}") + for explanation in explanations: + error_message_lines.append(f"- {explanation}") + error_message_lines.append( + "For further details, please refer to `Meridian.eda_outcomes`." + ) + raise ModelFittingError("\n".join(error_message_lines)) + def sample_posterior( self, n_chains: Sequence[int] | int, @@ -1663,6 +1707,8 @@ def sample_posterior( [ResourceExhaustedError when running Meridian.sample_posterior] (https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error). """ + self._run_model_fitting_guardrail() + self.posterior_sampler_callable( n_chains=n_chains, n_adapt=n_adapt, diff --git a/meridian/model/model_test.py b/meridian/model/model_test.py index e35e1a9bb..cd7775020 100644 --- a/meridian/model/model_test.py +++ b/meridian/model/model_test.py @@ -34,6 +34,7 @@ from meridian.model import prior_distribution from meridian.model import spec from meridian.model.eda import eda_engine +from meridian.model.eda import eda_outcome from meridian.model.eda import eda_spec as eda_spec_module import numpy as np import xarray as xr @@ -1668,6 +1669,50 @@ def test_load_error(self): ): model.load_mmm("this/path/does/not/exist") + def test_run_model_fitting_guardrail_error_message(self): + # Create mock EDA outcomes with ERROR severity findings + mock_finding1 = mock.Mock() + mock_finding1.severity = eda_outcome.EDASeverity.ERROR + mock_finding1.explanation = "Error explanation for PAIRWISE_CORR 1." + + mock_finding2 = mock.Mock() + mock_finding2.severity = eda_outcome.EDASeverity.ERROR + mock_finding2.explanation = "Error explanation for PAIRWISE_CORR 2." + + mock_finding3 = mock.Mock() + mock_finding3.severity = eda_outcome.EDASeverity.ERROR + mock_finding3.explanation = "Error explanation for MULTICOLLINEARITY 1." + + mock_outcome1 = mock.Mock() + mock_outcome1.check_type = eda_outcome.EDACheckType.PAIRWISE_CORRELATION + mock_outcome1.findings = [mock_finding1, mock_finding2] + + mock_outcome2 = mock.Mock() + mock_outcome2.check_type = eda_outcome.EDACheckType.MULTICOLLINEARITY + mock_outcome2.findings = [mock_finding3] + + mock_eda_outcomes = self.enter_context( + mock.patch( + "meridian.model.model.Meridian.eda_outcomes", + new_callable=mock.PropertyMock, + ) + ) + mock_eda_outcomes.return_value = [mock_outcome1, mock_outcome2] + meridian = model.Meridian(input_data=self.input_data_with_media_only) + + expected_error_message = ( + "Model has critical EDA issues. Please fix before running" + " `sample_posterior`.\n\nCheck type: PAIRWISE_CORRELATION\n- Error" + " explanation for PAIRWISE_CORR 1.\n- Error explanation for" + " PAIRWISE_CORR 2.\nCheck type: MULTICOLLINEARITY\n- Error explanation" + " for MULTICOLLINEARITY 1.\nFor further details, please refer to" + " `Meridian.eda_outcomes`." + ) + with self.assertRaisesWithLiteralMatch( + model.ModelFittingError, expected_error_message + ): + meridian.sample_posterior(n_chains=1, n_adapt=1, n_burnin=1, n_keep=1) + class NonPaidModelTest( test_utils.MeridianTestCase, diff --git a/meridian/model/posterior_sampler_test.py b/meridian/model/posterior_sampler_test.py index 3c85c9592..e3964a736 100644 --- a/meridian/model/posterior_sampler_test.py +++ b/meridian/model/posterior_sampler_test.py @@ -41,6 +41,17 @@ def setUpClass(cls): super().setUpClass() model_test_data.WithInputDataSamples.setup() + def setUp(self): + super().setUp() + self.enter_context( + mock.patch.object( + model.Meridian, + "_run_model_fitting_guardrail", + autospec=True, + return_value=None, + ) + ) + def _assert_seeds_equal(self, seed1, seed2): if backend.config.get_backend() == backend.config.Backend.JAX: self.assertEqual(seed1, seed2)