diff --git a/.github/actions/install-meridian/action.yml b/.github/actions/install-meridian/action.yml index a9d88826c..897e7f127 100644 --- a/.github/actions/install-meridian/action.yml +++ b/.github/actions/install-meridian/action.yml @@ -10,6 +10,10 @@ runs: using: composite steps: # Install deps + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + version: "27.x" - uses: actions/setup-python@v5 with: python-version: ${{ inputs.python_version }} @@ -17,6 +21,6 @@ runs: cache-dependency-path: '**/pyproject.toml' - shell: bash run: | - pip install -e .[dev,jax,mlflow] --config-settings editable_mode=strict + pip install -e .[dev,jax,mlflow,schema] --config-settings editable_mode=strict pip freeze diff --git a/proto/LICENSE b/proto/LICENSE new file mode 100644 index 000000000..7a4a3ea24 --- /dev/null +++ b/proto/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/proto/README.md b/proto/README.md new file mode 100644 index 000000000..06c4dd912 --- /dev/null +++ b/proto/README.md @@ -0,0 +1,21 @@ +# About MMM Proto Schema + +The MMM Proto Schema is a language-agnostic data standard that provides a +consistent and serializable way to represent a trained Marketing Mix Model (MMM) +and its analysis artifacts. Its core purpose is to establish a common language +for the outputs of an MMM. This allows the results from models built using +different tools or methodologies to be uniformly represented, stored, shared, and +compared across applications and workflows. By offering this standardized +representation, the schema aims to enhance interoperability and facilitate +downstream applications, such as scenario planning, optimization, and consistent +reporting, independent of how the original model was constructed. + +## Install Meridian with MMM Proto Schema + +Currently, this package can only be installed from source code: + +```sh +git clone https://github.com/google/meridian.git +cd meridian +pip install .[schema] +``` diff --git a/proto/__init__.py b/proto/__init__.py new file mode 100644 index 000000000..de5fca83d --- /dev/null +++ b/proto/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for MMM Proto Schema.""" +__version__ = "1.0.0" diff --git a/proto/mmm/v1/common/date_interval.proto b/proto/mmm/v1/common/date_interval.proto new file mode 100644 index 000000000..3876305bd --- /dev/null +++ b/proto/mmm/v1/common/date_interval.proto @@ -0,0 +1,34 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.common; + +import "google/type/date.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; +option java_package = "com.google.protos.mmm.v1.common"; + +message DateInterval { + // The start date of the interval. Inclusive. Required. + google.type.Date start_date = 1; + + // The end date of the interval. Exclusive. Required. + google.type.Date end_date = 2; + + // A tag to identify the date interval. Optional. + string tag = 3; +} diff --git a/proto/mmm/v1/common/estimate.proto b/proto/mmm/v1/common/estimate.proto new file mode 100644 index 000000000..efb56a14b --- /dev/null +++ b/proto/mmm/v1/common/estimate.proto @@ -0,0 +1,43 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.common; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// Contains an estimate value of an estimand and associated quantified +// uncertainties. +message Estimate { + // The estimate value of an estimand. Required. + double value = 1; + + // The uncertainty of an estimate quantified by a probability interval. + message Uncertainty { + // The probability that a value is inside the interval bounded by + // `lowerbound` and `upperbound`. Required. + double probability = 1; + + // The lower bound of the interval. Required. + double lowerbound = 2; + + // The upper bound of the interval. Required. + double upperbound = 3; + } + + // The quantified uncertainties. + repeated Uncertainty uncertainties = 2; +} diff --git a/proto/mmm/v1/common/kpi_type.proto b/proto/mmm/v1/common/kpi_type.proto new file mode 100644 index 000000000..53b725ec4 --- /dev/null +++ b/proto/mmm/v1/common/kpi_type.proto @@ -0,0 +1,33 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.common; + +option java_multiple_files = true; +option java_package = "com.google.protos.mmm.v1.common"; + +// Different KPI types used in marketing performance and optimization.
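+// For example, a KPI measured in sales units or conversions is NON_REVENUE +// (and can still be expressed in revenue terms when a `revenue_per_kpi` +// conversion factor is available; see `mmm.v1.marketing.Kpi`), while a KPI +// measured directly in currency is REVENUE.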
+enum KpiType { + KPI_TYPE_UNSPECIFIED = 0; + + // Some generic user-defined KPI unit. + NON_REVENUE = 1; + + // KPI defined as revenue specifically, or some KPI unit after conversion + // to revenue. + REVENUE = 2; +} diff --git a/proto/mmm/v1/common/target_metric.proto b/proto/mmm/v1/common/target_metric.proto new file mode 100644 index 000000000..46c555a8d --- /dev/null +++ b/proto/mmm/v1/common/target_metric.proto @@ -0,0 +1,41 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.common; + +option java_multiple_files = true; + +// Target metrics for marketing performance optimizations. +// Note that each of these metric variants can be interpreted in terms of either +// revenue or non-revenue KPI _type_. See: `common.KpiType`. +enum TargetMetric { + TARGET_METRIC_UNSPECIFIED = 0; + + // Any KPI type (revenue if the model data can be converted to revenue, or + // some generic KPI otherwise). + KPI = 1; + + // ROI = net KPI change / spend. + ROI = 2; + + // Marginal ROI is defined at the channel level and is applied across all + // channels. + // mROI = change in incremental KPI / spend. + MARGINAL_ROI = 3; + + // CPIK = total spend / change in incremental KPI. + COST_PER_INCREMENTAL_KPI = 4; +} diff --git a/proto/mmm/v1/fit/model_fit.proto b/proto/mmm/v1/fit/model_fit.proto new file mode 100644 index 000000000..695fb1e83 --- /dev/null +++ b/proto/mmm/v1/fit/model_fit.proto @@ -0,0 +1,70 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.fit; + +import "mmm/v1/common/date_interval.proto"; +import "mmm/v1/common/estimate.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// A prediction contains the predicted KPI and the ground truth at a specific +// time. +message Prediction { + // The time associated with this prediction point. Required. + common.DateInterval date_interval = 1; + + // The predicted outcome. Required. + common.Estimate predicted_outcome = 2; + + // The predicted baseline. Optional. + common.Estimate predicted_baseline = 3; + + // The actual value observed in the data. Required. + double actual_value = 4; +} + +// The model fit performance indicated by different metrics. +message Performance { + double r_squared = 1; + + // Mean absolute percentage error.
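+ // MAPE = mean_i(|(actual_i - pred_i) / actual_i|).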
+ double mape = 2; + + // Weighted MAPE = sum_i(|actual_i - pred_i|) / sum_i(actual_i). + double weighted_mape = 3; + + // Root mean square error. + double rmse = 4; +} + +message Result { + // The name of the result. Required. + string name = 1; + + // The predictions over different times. Required. + repeated Prediction predictions = 2; + + // The performance of the model fit. Required. + Performance performance = 3; +} + +message ModelFit { + // Different results for different purposes. For example, one could divide the + // data into training, testing, and validation sets. Required. + repeated Result results = 1; +} diff --git a/proto/mmm/v1/marketing/analysis/marketing_analysis.proto b/proto/mmm/v1/marketing/analysis/marketing_analysis.proto new file mode 100644 index 000000000..05df68ba2 --- /dev/null +++ b/proto/mmm/v1/marketing/analysis/marketing_analysis.proto @@ -0,0 +1,41 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.analysis; + +import "mmm/v1/common/date_interval.proto"; +import "mmm/v1/marketing/analysis/media_analysis.proto"; +import "mmm/v1/marketing/analysis/non_media_analysis.proto"; + +option java_multiple_files = true; + +// The marketing analysis. +message MarketingAnalysis { + // The date interval that the analysis covers. Required. + common.DateInterval date_interval = 1; + + // Analyses of different media channels. Required. + repeated MediaAnalysis media_analyses = 2; + + // Analyses of different non-media factors. + repeated NonMediaAnalysis non_media_analyses = 3; +} + +// A list of marketing analyses. +message MarketingAnalysisList { + // The marketing analyses for different time ranges. Required. + repeated MarketingAnalysis marketing_analyses = 1; +} diff --git a/proto/mmm/v1/marketing/analysis/media_analysis.proto b/proto/mmm/v1/marketing/analysis/media_analysis.proto new file mode 100644 index 000000000..0c81360dc --- /dev/null +++ b/proto/mmm/v1/marketing/analysis/media_analysis.proto @@ -0,0 +1,60 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.analysis; + +import "mmm/v1/marketing/analysis/outcome.proto"; +import "mmm/v1/marketing/analysis/response_curve.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +message SpendInfo { + // The amount spent on the media channel.
Required. + double spend = 1; + + // Spend share = spend / total spend. Required. + double spend_share = 2; +} + +// An analysis of the relationship between a media channel's spend variable and +// its KPI outcome(s). +message MediaAnalysis { + reserved 3; + + // The name of the media channel. Required. + string channel_name = 1; + + // The spend information of this media channel. + // + // This is optional and is left unset for a non-paid media channel analysis. + SpendInfo spend_info = 2; + + // The marketing outcomes of advertising from this media channel. Required. + // + // The outcome(s) calculated based on revenue and/or generic non-revenue KPI. + // One or more outcome values are set when revenue and/or generic non-revenue + // KPI outcome information is available. + // + // For non-paid media, the spend-related fields in `Outcome` won't be + // set. + repeated Outcome media_outcomes = 5; + + // A response curve for the media channel. + // + // For non-paid media, the response curve is not available due to the lack of + // spend information. + ResponseCurve response_curve = 4; +} diff --git a/proto/mmm/v1/marketing/analysis/non_media_analysis.proto b/proto/mmm/v1/marketing/analysis/non_media_analysis.proto new file mode 100644 index 000000000..1ebd6574d --- /dev/null +++ b/proto/mmm/v1/marketing/analysis/non_media_analysis.proto @@ -0,0 +1,40 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.analysis; + +import "mmm/v1/marketing/analysis/outcome.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// An analysis of a non-media factor. +message NonMediaAnalysis { + reserved 2; + + // The name of the non-media factor. Required. + string non_media_name = 1; + + // The marketing outcomes from this non-media factor. Required. + // + // The non-media outcome(s) calculated based on revenue and/or generic + // non-revenue KPI. One or more outcome values are set when revenue and/or + // generic non-revenue KPI outcome information is available. + // + // The spend-related fields herein won't be set, as a non-media factor has no + // spend. + repeated Outcome non_media_outcomes = 3; +} diff --git a/proto/mmm/v1/marketing/analysis/outcome.proto b/proto/mmm/v1/marketing/analysis/outcome.proto new file mode 100644 index 000000000..99df64620 --- /dev/null +++ b/proto/mmm/v1/marketing/analysis/outcome.proto @@ -0,0 +1,77 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.analysis; + +import "mmm/v1/common/estimate.proto"; +import "mmm/v1/common/kpi_type.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// A contribution is the incremental outcome value attributable to a given +// factor. +message Contribution { + // The contribution value. Required. + common.Estimate value = 1; + + // Share of contribution = contribution / total contribution from all + // outcomes. + common.Estimate share = 2; +} + +// Effectiveness measures how much incremental KPI is generated per media unit +// (e.g. impressions or clicks), i.e. contribution / media units. +message Effectiveness { + // The media unit of the effectiveness. Required. + string media_unit = 1; + + // The value of the effectiveness. Required. + common.Estimate value = 2; +} + +// An outcome analysis on a KPI, which can be defined as revenue or some other +// generic non-revenue type. +message Outcome { + // The type of this KPI (i.e. REVENUE or NON_REVENUE). + // Note that a model input with non-revenue (generic KPI) data can still have + // revenue-based KPI outcomes defined, provided that `revenue_per_kpi` is + // defined. + common.KpiType kpi_type = 1; + + // The contribution to a KPI. + // If `kpi_type == REVENUE`, this is the revenue KPI value. + // If `kpi_type == NON_REVENUE`, AND there is a `revenue_per_kpi` conversion, + // this is the derived `kpi * revenue_per_kpi` value. + // Otherwise, this is simply the (non-revenue, user-defined) KPI value. + Contribution contribution = 2; + + // The effectiveness of this outcome. + Effectiveness effectiveness = 3; + + // ROI = contribution / spend. + // See the contribution definition above. + common.Estimate roi = 4; + + // Marginal ROI shows the additional ROI gained from additional spend. + // See the ROI definition above. + common.Estimate marginal_roi = 5; + + // Cost per incremental outcome (which could be revenue or some generic KPI). + // E.g. when the contribution unit is a thousand impressions, this is CPM; + // when the contribution is an acquisition, this is CPA. + common.Estimate cost_per_contribution = 6; +} diff --git a/proto/mmm/v1/marketing/analysis/response_curve.proto b/proto/mmm/v1/marketing/analysis/response_curve.proto new file mode 100644 index 000000000..ae4c2b7ff --- /dev/null +++ b/proto/mmm/v1/marketing/analysis/response_curve.proto @@ -0,0 +1,39 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.analysis; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +message ResponsePoint { + // The amount of the input that drives the incremental KPI. Required. + double input_value = 1; + + // The incremental KPI caused by the input. Required.
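+ // Illustrative example: a point with input_value: 10000 and + // incremental_kpi: 2500 means an input of 10,000 (e.g. spend) is predicted + // to drive 2,500 incremental KPI units.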
+ double incremental_kpi = 2; +} + +// A response curve is used to show how much incremental outcome moves in +// response to changes in the input value (e.g. spend on a paid media channel, +// advertising impressions in a channel, etc.). +message ResponseCurve { + // The name of the input. Required. + string input_name = 1; + + // The response points. Required. + repeated ResponsePoint response_points = 2; +} diff --git a/proto/mmm/v1/marketing/marketing_data.proto b/proto/mmm/v1/marketing/marketing_data.proto new file mode 100644 index 000000000..6b917de1b --- /dev/null +++ b/proto/mmm/v1/marketing/marketing_data.proto @@ -0,0 +1,278 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing; + +import "google/api/field_behavior.proto"; +import "google/type/date.proto"; +import "mmm/v1/common/date_interval.proto"; + +option java_multiple_files = true; + +message GeoInfo { + // The ID of the geo location. Required. + string geo_id = 1; + + // The population of the geo location. Required. + int64 population = 2; +} + +// A KPI (key performance indicator) can either be revenue directly, or some +// other metric that indirectly contributes to revenue (e.g. sales units, +// conversions, impressions) with a multiplier value that roughly translates +// the non-revenue KPI unit into revenue. +message Kpi { + // A revenue KPI. + message Revenue { + // The revenue value. + double value = 1; + } + + // A non-revenue type of KPI. + message NonRevenue { + // The value of the non-revenue KPI. + double value = 1; + + // Used to convert the non-revenue KPI value to revenue. + // + // Needs to be non-negative. + double revenue_per_kpi = 2; + } + + // The name of the KPI. Required. + string name = 1; + + // The type of the KPI. Required. + oneof type { + Revenue revenue = 2; + NonRevenue non_revenue = 3; + } +} + +// The control variable. A control variable is not directly studied but is +// included in the model to account for potential confounding effects on the +// relationship between the primary independent and dependent variables. +// Examples: seasonality and macroeconomic factors. +message ControlVariable { + // The name of the variable. Required. + string name = 1; + + // The value of the variable. Required. + double value = 2; +} + +// The non-media treatment variable. A marketing activity that is not directly +// related to media, such as running a promotion, the price of a product, and +// a change in a product's packaging or design. +message NonMediaTreatmentVariable { + // The name of the variable. Required. + string name = 1; + + // The value of the variable. Required. + double value = 2; +} + +// A scalar metric, e.g. impressions, clicks, or costs. +message ScalarMetric { + // The name of the scalar metric. Required. + string name = 1; + + // The value of the scalar metric. Required. + double value = 2; +} + +// Reach and frequency metric.
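+// Illustrative example: reach: 1000000 with average_frequency: 2.5 means one +// million unique users were reached 2.5 times on average.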
+message ReachFrequency { + // Reach value. Required. + int64 reach = 1; + + // Average frequency value. Required. + double average_frequency = 2; +} + +// The media variable. +message MediaVariable { + reserved 3; + + // The name of the media channel. Required. + string channel_name = 1; + + // Scalar metric measured on the channel. Required. + ScalarMetric scalar_metric = 2; + + // Spend on the media. + // + // If the media variable is paid media, spend is required. + double media_spend = 4; +} + +// The reach and frequency variable. +message ReachFrequencyVariable { + // The name of the reach and frequency variable. Required. + string channel_name = 1; + + // The reach value. Required. + int64 reach = 2; + + // The average frequency value. Required. + double average_frequency = 3; + + // The spend value. + // + // If the reach and frequency variable is paid media, spend is required. + double spend = 4; +} + +// A data point contains marketing information at a specific geo and time. +message MarketingDataPoint { + // Geo info of this data point. + // If unset, this data point is aggregated across all geos in the model's + // geo coordinates. + GeoInfo geo_info = 1; + + // Date interval covered by this data point. Required. + // This can represent either a coordinate point or an aggregation over a time + // dimension's coordinates. In the latter case, this field should be + // defined with a `[start, end + interval]` value, where `start` and `end` are + // the first and last coordinate in that time dimension, respectively. + common.DateInterval date_interval = 2; + + // The control variables associated with the marketing at this geo and time. + repeated ControlVariable control_variables = 3; + + // The media variables associated with the marketing at this geo and time. + // + // If a media variable is from a paid media channel and its media spend + // breakdown by geo and time is not available (i.e. media spend is aggregated + // across all geos and times), then there should be a separate + // `MarketingDataPoint` message with `media_spend` where `geo_info` is unset + // and `date_interval` spans the entire time dimension's coordinates. + // + // Media channel names should be unique across this group of media variables. + repeated MediaVariable media_variables = 4; + + // The reach and frequency variables associated with the marketing at this geo + // and time. + // + // If a reach and frequency variable is from a paid media channel and its + // spend breakdown by geo and time is not available (i.e. spend is + // aggregated across all geos and times), then there should be a separate + // `MarketingDataPoint` message with `spend` where `geo_info` is unset and + // `date_interval` spans the entire time dimension's coordinates. + // + // Reach and frequency variable names should be unique across this group of + // reach and frequency variables. + repeated ReachFrequencyVariable reach_frequency_variables = 6; + + // The KPI associated with the marketing at this geo and time. + // KPI type must be consistent across all data points. + Kpi kpi = 5; + + // Non-media treatment variables associated with this data point. + repeated NonMediaTreatmentVariable non_media_treatment_variables = 7; +} + +// Metadata useful for validating data points and recreating the model data in +// its domain language. +message MarketingDataMetadata { + // A named set of time coordinates. + message TimeDimension { + // A name for this set of time coordinates. Optional.
+ string name = 1; + + // The coordinates of this time dimension in the model, as dates. Required. + repeated google.type.Date dates = 2; + } + + // One or more sets of time coordinates. Required. + repeated TimeDimension time_dimensions = 1; + + message GeoDimension { + repeated string geo_coordinates = 1; + } + + // The geo dimension of the model. Required. + GeoDimension geo_dimension = 2; + + // A named set of channel name coordinates. + message ChannelDimension { + // A name for this set of channel names in this dimension. Optional. + string name = 1; + + // The channel names in this set's dimensional coordinates. Required. + repeated string channels = 2; + } + + // One or more sets of channel names. Required. + repeated ChannelDimension channel_dimensions = 3; + + // Control variable names. + repeated string control_names = 4; + + // The KPI type. + string kpi_type = 5; + + // Non-media treatment variable names. + repeated string non_media_treatment_names = 6; +} + +// A collection of marketing data points for different combinations of geo +// locations and times, used for model training. +message MarketingData { + // The marketing data points. Required. + repeated MarketingDataPoint marketing_data_points = 1; + + // Metadata useful for recreating the model data in its domain language. + MarketingDataMetadata metadata = 2; +} + +// A new marketing data point used for model inference. This contains +// independent marketing data at a specific geo and time. +message NewMarketingDataPoint { + // Geo info of this data point. Required for a geo model. + GeoInfo geo_info = 1 [(google.api.field_behavior) = OPTIONAL]; + + // Date interval covered by this data point. + // This can represent either a coordinate point or an aggregation over a time + // dimension's coordinates. In the latter case, this field should be + // defined with a `[start, end + interval]` value, where `start` and `end` are + // the first and last coordinate in that time dimension, respectively. + common.DateInterval date_interval = 2 + [(google.api.field_behavior) = REQUIRED]; + + // The media variables associated with the marketing at this geo and time. + repeated MediaVariable media_variables = 3 + [(google.api.field_behavior) = OPTIONAL]; + + // The reach and frequency variables associated with the marketing at this geo + // and time. + repeated ReachFrequencyVariable reach_frequency_variables = 5 + [(google.api.field_behavior) = OPTIONAL]; + + // The revenue per KPI associated with the marketing at this geo and time. + // Required for revenue analysis. + double revenue_per_kpi = 4 [(google.api.field_behavior) = OPTIONAL]; +} + +// A collection of independent marketing data points for different combinations +// of geo locations and times, used to override training data for model +// inference. This can span any time period, including times overlapping with +// and beyond the modeling period. +message NewMarketingData { + // The independent marketing data points for each geo and time. + repeated NewMarketingDataPoint marketing_data_points = 1 + [(google.api.field_behavior) = OPTIONAL]; +} diff --git a/proto/mmm/v1/marketing/optimization/budget_optimization.proto b/proto/mmm/v1/marketing/optimization/budget_optimization.proto new file mode 100644 index 000000000..b11c742a2 --- /dev/null +++ b/proto/mmm/v1/marketing/optimization/budget_optimization.proto @@ -0,0 +1,182 @@ +// Copyright 2025 The Meridian Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.optimization; + +import "mmm/v1/common/date_interval.proto"; +import "mmm/v1/common/estimate.proto"; +import "mmm/v1/common/kpi_type.proto"; +import "mmm/v1/common/target_metric.proto"; +import "mmm/v1/marketing/analysis/marketing_analysis.proto"; +import "mmm/v1/marketing/marketing_data.proto"; +import "mmm/v1/marketing/optimization/constraints.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// A fixed budget scenario for optimizing budget allocations over channels. +message FixedBudgetScenario { + // The budget amount. Required. + double total_budget = 1; +} + +// A flexible budget scenario for optimizing budget allocations over channels. +message FlexibleBudgetScenario { + // The constraint parameters on the total budget. + BudgetConstraint total_budget_constraint = 1; + + // The constraints on target metrics (e.g. KPI, ROI, etc.). + repeated TargetMetricConstraint target_metric_constraints = 2; +} + +// Channel-level constraint. +message ChannelConstraint { + // The name of the channel. Required. + string channel_name = 1; + + // The budget constraint on the channel. + BudgetConstraint budget_constraint = 2; +} + +// Input to the optimizer. +message BudgetOptimizationSpec { + // The date interval defines the selection of the time points that the + // optimization is based upon. + common.DateInterval date_interval = 1; + + // The objective to maximize in the budget optimization. Required. + common.TargetMetric objective = 2; + + // The type of KPI used to derive the optimization objective. Required. + common.KpiType kpi_type = 6; + + // The new marketing data to override the flighting pattern and CPM. + // If not provided, the optimization will be based on the historical data. + marketing.NewMarketingData new_marketing_data = 7; + + // The optimization scenario. Required. + oneof scenario { + // A fixed budget optimization tries to maximize an objective by optimizing + // the budget allocations over channels without changing the total budget + // amount. + // + // For instance, in Meridian, the objective function is chosen to be the + // posterior mean of the expected KPI (e.g. sales, revenue, etc.). + FixedBudgetScenario fixed_budget_scenario = 3; + + // A flexible budget optimization tries to maximize an objective by + // optimizing the budget allocations over channels with a flexible total + // budget amount. + // + // For instance, in Meridian, the expected KPI (e.g. revenue) can be + // optimized while allowing the total budget to vary: this flexible + // optimization in Meridian is constrained by either a minimal marginal + // ROI or a target ROI constraint. + FlexibleBudgetScenario flexible_budget_scenario = 4; + } + + // The constraints on channels. + // If a media channel that is present in the model is not represented here, + // it will be given the default constraint of `[0, max_budget]`.
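+ // Illustrative example (hypothetical channel name and values): + // channel_constraints { + // channel_name: "tv" + // budget_constraint { min_budget: 10000 max_budget: 50000 } + // }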
+ repeated ChannelConstraint channel_constraints = 5; +} + +// A message representing a grid that details the incremental outcome of +// marketing spend by channel. +// +// Note that this grid is constructed under the assumption that there is no +// interaction effect across channels, i.e. the spend on one channel will not +// affect other channels. +message IncrementalOutcomeGrid { + // A data point within the grid representing the outcome of a specific spend + // on a particular channel. + message Cell { + // The amount of marketing spend allocated to the channel. + double spend = 1; + + // The incremental outcome achieved through the channel spend. The type + // should be indicated by the objective in the optimization spec. + // + // This is calculated as the difference between the outcome with the given + // spend and the outcome with zero spend (outcome(spend) - outcome(0)). + common.Estimate incremental_outcome = 2; + } + + // A collection of cells in a channel. + message ChannelCells { + // The name of the marketing channel. + string channel_name = 1; + + // The cells in the channel. + repeated Cell cells = 2; + } + + // The name of the grid. Required. + string name = 1; + + // The uniform step size between consecutive spend values within a channel. + // Required. + double spend_step_size = 2; + + // The collection of cells representing all combinations of spend and + // incremental outcome across channels. Required. + // + // Each channel can have a different spend range, but all spend values within + // a channel must be spaced evenly using the specified step size. + repeated ChannelCells channel_cells = 3; +} + +// The budget optimization finds the result of optimal budget allocation given +// an optimization spec. +message BudgetOptimizationResult { + reserved 4; + + // The name of the budget optimization. Required. + string name = 1; + + // An optional identifier for a result that belongs to a group of related + // results (of different types). + // Note that no two `BudgetOptimizationResult`s should share the same group + // ID. Simple UUID strings are recommended. + string group_id = 5; + + // The optimization spec used to generate the result. Required. + BudgetOptimizationSpec spec = 2; + + // The analysis of the marketing outcome when using the optimized budget. + // Required. + // + // The non-media outcomes are not optimized, but some fields might be impacted + // by the change of media outcomes. For example, total contribution would + // change, so the contribution share values have to be modified accordingly. + analysis.MarketingAnalysis optimized_marketing_analysis = 3; + + // The non-optimized marketing outcome. Required. + // + // In a fixed budget scenario, the non-optimized marketing outcome is based on + // the budget amount. In a flexible budget scenario, the outcome is based on + // the historical spend. + analysis.MarketingAnalysis nonoptimized_marketing_analysis = 7; + + // Optional search grid that describes incremental outcomes of spends on + // channels. Useful for speeding up optimization analysis. + IncrementalOutcomeGrid incremental_outcome_grid = 6; +} + +message BudgetOptimization { + // Optimization results for different scenarios.
+ repeated BudgetOptimizationResult results = 1; +} diff --git a/proto/mmm/v1/marketing/optimization/constraints.proto b/proto/mmm/v1/marketing/optimization/constraints.proto new file mode 100644 index 000000000..0a56a41a6 --- /dev/null +++ b/proto/mmm/v1/marketing/optimization/constraints.proto @@ -0,0 +1,53 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.optimization; + +import "mmm/v1/common/target_metric.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// The constraint on target metrics. +message TargetMetricConstraint { + reserved 2, 3; + + // The type of the target metric that is constrained. Required. + common.TargetMetric target_metric = 1; + + // The constraint on the target metric value. Required. + // Whether this target value represents a lower or upper bound depends on the + // target metric set above. + double target_value = 4; +} + +message BudgetConstraint { + // Required. + // Absolute minimum budget value. + double min_budget = 1; + + // Required. + // Absolute maximum budget value. + double max_budget = 2; +} + +message FrequencyConstraint { + // Required. + double min_frequency = 1; + + // Required. + double max_frequency = 2; +} diff --git a/proto/mmm/v1/marketing/optimization/marketing_optimization.proto b/proto/mmm/v1/marketing/optimization/marketing_optimization.proto new file mode 100644 index 000000000..437208cdc --- /dev/null +++ b/proto/mmm/v1/marketing/optimization/marketing_optimization.proto @@ -0,0 +1,31 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.optimization; + +import "mmm/v1/marketing/optimization/budget_optimization.proto"; +import "mmm/v1/marketing/optimization/reach_frequency_optimization.proto"; + +option java_multiple_files = true; + +// Marketing optimization contains all optimization-related results. +message MarketingOptimization { + // Budget optimization that contains results for different scenarios. + BudgetOptimization budget_optimization = 1; + + // Reach frequency optimization that contains results for different scenarios.
+ ReachFrequencyOptimization reach_frequency_optimization = 2; +} diff --git a/proto/mmm/v1/marketing/optimization/reach_frequency_optimization.proto b/proto/mmm/v1/marketing/optimization/reach_frequency_optimization.proto new file mode 100644 index 000000000..ea79c4936 --- /dev/null +++ b/proto/mmm/v1/marketing/optimization/reach_frequency_optimization.proto @@ -0,0 +1,148 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1.marketing.optimization; + +import "mmm/v1/common/date_interval.proto"; +import "mmm/v1/common/estimate.proto"; +import "mmm/v1/common/kpi_type.proto"; +import "mmm/v1/common/target_metric.proto"; +import "mmm/v1/marketing/analysis/marketing_analysis.proto"; +import "mmm/v1/marketing/marketing_data.proto"; +import "mmm/v1/marketing/optimization/constraints.proto"; + +option features.field_presence = IMPLICIT; +option java_multiple_files = true; + +// Channel-level constraint for a channel that has reach and frequency +// information. +message RfChannelConstraint { + // The name of the channel. Required. + string channel_name = 1; + + // The budget constraint on the channel. + BudgetConstraint budget_constraint = 2; + + // The frequency constraint on the channel. + FrequencyConstraint frequency_constraint = 3; +} + +message ReachFrequencyOptimizationSpec { + // The date interval defines the selection of the time points that the + // optimization is based upon. + common.DateInterval date_interval = 1; + + // The objective to maximize in the reach frequency optimization. Required. + common.TargetMetric objective = 2; + + // The type of KPI used to derive the optimization objective. Required. + common.KpiType kpi_type = 6; + + // The constraints on target metrics (e.g. KPI, ROI, etc.). + repeated TargetMetricConstraint target_metric_constraints = 3; + + // The constraint on the total budget. + BudgetConstraint total_budget_constraint = 4; + + // The constraints on channels that have reach and frequency information. + repeated RfChannelConstraint rf_channel_constraints = 5; +} + +// Reach frequency optimization result for a channel that has reach and +// frequency information. +message OptimizedChannelFrequency { + // The name of the channel. Required. + string channel_name = 1; + + // The optimal average frequency of the channel. Required. + double optimal_average_frequency = 2; +} + +// A message representing a grid that details the outcome of reach and +// frequency by channel. +// +// Note that this grid is constructed under the assumption that there is no +// interaction effect across channels, i.e. the reach and frequency on one +// channel will not affect other channels. +message FrequencyOutcomeGrid { + // A data point within the grid representing the outcome of a specific reach + // and frequency on a particular channel. + message Cell { + // The reach and frequency of the channel. + ReachFrequency reach_frequency = 1; + + // The outcome achieved through the channel reach and frequency.
The type should + // be indicated by the objective in the optimization spec. + common.Estimate outcome = 2; + } + + // A collection of cells in a channel. + message ChannelCells { + // The name of the marketing channel. + string channel_name = 1; + + // The cells in the channel. + repeated Cell cells = 2; + } + + // The name of the grid. Required. + string name = 1; + + // The uniform step size between consecutive frequency values within a + // channel. Required. + double frequency_step_size = 2; + + // The collection of cells representing all combinations of reach and + // frequency and outcome across channels. Required. + repeated ChannelCells channel_cells = 3; +} + +message ReachFrequencyOptimizationResult { + reserved 5; + + // The name of the reach frequency optimization. Required. + string name = 1; + + // An optional identifier for a result that belongs to a group of related + // results (of different types). + // Note that no two `ReachFrequencyOptimizationResult`s should share the same + // group ID. + // Simple UUID strings are recommended. + string group_id = 6; + + // The optimization spec used to generate the result. Required. + ReachFrequencyOptimizationSpec spec = 2; + + // Optimal average frequency results by channel. Required. + // + // Media channels without reach and frequency data are not included. + repeated OptimizedChannelFrequency optimized_channel_frequencies = 3; + + // The analysis of the marketing outcome when using the optimal channel + // frequencies. Required. + // + // The non-media outcomes are not optimized, but some fields might be impacted + // by the change of media outcomes. For example, total contribution would + // change, so the contribution share values have to be modified accordingly. + analysis.MarketingAnalysis optimized_marketing_analysis = 4; + + // Optional grid that describes the outcomes of reach and frequency on + // channels. + FrequencyOutcomeGrid frequency_outcome_grid = 7; +} + +message ReachFrequencyOptimization { + // Optimization results for different scenarios. + repeated ReachFrequencyOptimizationResult results = 1; +} diff --git a/proto/mmm/v1/mmm.proto b/proto/mmm/v1/mmm.proto new file mode 100644 index 000000000..4ee96f0ed --- /dev/null +++ b/proto/mmm/v1/mmm.proto @@ -0,0 +1,42 @@ +// Copyright 2025 The Meridian Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package mmm.v1; + +import "mmm/v1/fit/model_fit.proto"; +import "mmm/v1/marketing/analysis/marketing_analysis.proto"; +import "mmm/v1/marketing/optimization/marketing_optimization.proto"; +import "mmm/v1/model/mmm_kernel.proto"; + +option java_multiple_files = true; +option java_package = "com.google.protos.mmm.v1"; + +// A schema that contains derived metrics and modeled analyses produced by a +// trained marketing mix model. +message Mmm { + // An MMM kernel contains the core information about the model used to + // generate this output. + model.MmmKernel mmm_kernel = 1; + + // Model fit result.
diff --git a/proto/mmm/v1/mmm.proto b/proto/mmm/v1/mmm.proto
new file mode 100644
index 000000000..4ee96f0ed
--- /dev/null
+++ b/proto/mmm/v1/mmm.proto
@@ -0,0 +1,42 @@
+// Copyright 2025 The Meridian Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+edition = "2023";
+
+package mmm.v1;
+
+import "mmm/v1/fit/model_fit.proto";
+import "mmm/v1/marketing/analysis/marketing_analysis.proto";
+import "mmm/v1/marketing/optimization/marketing_optimization.proto";
+import "mmm/v1/model/mmm_kernel.proto";
+
+option java_multiple_files = true;
+option java_package = "com.google.protos.mmm.v1";
+
+// A schema that contains derived metrics and modeled analysis by a trained
+// marketing mix model.
+message Mmm {
+  // An MMM kernel contains the core information about the model used to
+  // generate this output.
+  model.MmmKernel mmm_kernel = 1;
+
+  // Model fit result.
+  fit.ModelFit model_fit = 2;
+
+  // A list of marketing analyses generated by the MMM kernel.
+  marketing.analysis.MarketingAnalysisList marketing_analysis_list = 3;
+
+  // Marketing optimization from different perspectives using the MMM kernel.
+  marketing.optimization.MarketingOptimization marketing_optimization = 4;
+}
diff --git a/proto/mmm/v1/model/meridian/meridian_model.proto b/proto/mmm/v1/model/meridian/meridian_model.proto
new file mode 100644
index 000000000..170024368
--- /dev/null
+++ b/proto/mmm/v1/model/meridian/meridian_model.proto
@@ -0,0 +1,661 @@
+// Copyright 2025 The Meridian Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+edition = "2023";
+
+package mmm.v1.model.meridian;
+
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+
+option features.field_presence = IMPLICIT;
+option java_package = "com.google.protos.mmm.v1.model.meridian";
+option java_multiple_files = true;
+
+// Represents TensorFlow statistical distributions that are used in user priors
+// in a Meridian model.
+// All fields are required unless otherwise specified.
+// See: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions
+message Distribution {
+  // Represents TensorFlow bijectors.
+  // All fields are required unless otherwise specified.
+  // See: https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors
+  message Bijector {
+    reserved 1;
+
+    // A bijector that shifts the input by a scalar.
+    message Shift {
+      // The shift to apply to the input.
+      repeated double shifts = 1;
+    }
+
+    // A bijector that scales the input by a scalar or log scale.
+    message Scale {
+      // The scale to apply to the input. Should not be set if `log_scales`
+      // is set.
+      repeated double scales = 1;
+
+      // The log scale to apply to the input. Should not be set if `scales`
+      // is set.
+      repeated double log_scales = 2;
+    }
+
+    // A bijector that computes the reciprocal of the input.
+    message Reciprocal {}
+
+    // The name of this bijector.
+    string name = 2 [features.field_presence = EXPLICIT];
+
+    oneof bijector_type {
+      Shift shift = 3;
+      Scale scale = 4;
+      Reciprocal reciprocal = 5;
+    }
+  }
+
+  // The following message types are the distribution types that Meridian
+  // supports.
+
+  // A distribution that broadcasts an underlying distribution's batch shape.
+  message BatchBroadcast {
+    // The underlying (pre-broadcast) distribution.
+    Distribution distribution = 1;
+
+    // The shape of the broadcast distribution.
+    tensorflow.TensorShapeProto batch_shape = 2;
+  }
+
+  // A distribution that is transformed by a bijector.
+  message Transformed {
+    // The underlying (pre-transformed) distribution.
+    Distribution distribution = 1;
+
+    // The transforming bijector.
+    Bijector bijector = 2;
+  }
+
+  // A scalar deterministic distribution on the real line.
+  message Deterministic {
+    // The batch of points on which this distribution is supported.
+    repeated double locs = 1;
+  }
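+
+  // For illustration (hypothetical values): a `Transformed` distribution that
+  // shifts a log-normal by 0.1, using message types defined in this file.
+  //
+  //   transformed {
+  //     distribution { log_normal { locs: 0.0 scales: 1.0 } }
+  //     bijector { shift { shifts: 0.1 } }
+  //   }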
+
+  // A half-normal distribution with scales.
+  message HalfNormal {
+    // The scales of the distribution.
+    repeated double scales = 1;
+  }
+
+  // A log-normal distribution with locs (means) and scales (stddevs).
+  message LogNormal {
+    // The means of the underlying Normal distribution.
+    repeated double locs = 1;
+
+    // The standard deviations of the underlying Normal distribution.
+    repeated double scales = 2;
+  }
+
+  // A normal distribution with locs (means) and scales (stddevs).
+  message Normal {
+    // The means of the distribution.
+    repeated double locs = 1;
+
+    // The standard deviations of the distribution.
+    // Must contain only positive values.
+    repeated double scales = 2;
+  }
+
+  // A truncated Normal distribution, bounded between `low` and `high`.
+  message TruncatedNormal {
+    // The means of the underlying Normal distribution.
+    repeated double locs = 1;
+
+    // The standard deviations of the underlying Normal distribution.
+    repeated double scales = 2;
+
+    // Lower bound of the distribution's support. Must be less than `high`.
+    double low = 3 [features.field_presence = EXPLICIT, deprecated = true];
+
+    // Upper bound of the distribution's support. Must be greater than `low`.
+    double high = 4 [features.field_presence = EXPLICIT, deprecated = true];
+
+    // Lower bounds of the distribution's support. Each value in `lows` must be
+    // less than the corresponding value in `highs`.
+    repeated double lows = 5;
+
+    // Upper bounds of the distribution's support. Each value in `highs` must
+    // be greater than the corresponding value in `lows`.
+    repeated double highs = 6;
+  }
+
+  // A uniform distribution on the real line.
+  message Uniform {
+    // Lower boundary of the output interval. Must be less than `high`.
+    double low = 1 [features.field_presence = EXPLICIT, deprecated = true];
+
+    // Upper boundary of the output interval. Must be greater than `low`.
+    double high = 2 [features.field_presence = EXPLICIT, deprecated = true];
+
+    // Lower boundaries of the output interval. Each value in `lows` must be
+    // less than the corresponding value in `highs`.
+    repeated double lows = 3;
+
+    // Upper boundaries of the output interval. Each value in `highs` must be
+    // greater than the corresponding value in `lows`.
+    repeated double highs = 4;
+  }
+
+  // A Beta distribution with alpha and beta parameters.
+  // See:
+  // https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Beta
+  message Beta {
+    // The alpha parameter of the Beta distribution.
+    repeated double alpha = 1;
+
+    // The beta parameter of the Beta distribution.
+    repeated double beta = 2;
+  }
+
+  // The canonical name of this distribution in the Meridian model framework.
+  string name = 1 [features.field_presence = EXPLICIT];
+
+  oneof distribution_type {
+    BatchBroadcast batch_broadcast = 2;
+    Deterministic deterministic = 3;
+    HalfNormal half_normal = 4;
+    LogNormal log_normal = 5;
+    Normal normal = 6;
+    Transformed transformed = 7;
+    TruncatedNormal truncated_normal = 8;
+    Uniform uniform = 9;
+    Beta beta = 10;
+  }
+}
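+
+// For illustration (hypothetical values): a `Distribution` carrying a batch of
+// two truncated normals via the vectorized `lows`/`highs` fields (preferred
+// over the deprecated scalar `low`/`high`):
+//
+//   name: "ec_m"
+//   truncated_normal {
+//     locs: [0.8, 0.8]
+//     scales: [0.8, 0.8]
+//     lows: [0.1, 0.1]
+//     highs: [10.0, 10.0]
+//   }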
+
+// A container for user prior distribution parameters in a Meridian model.
+// These distributions are in their mathematical forms when representing
+// user priors in the model spec and are part of the user input in a pre-fitted
+// model.
+//
+// After priors sampling, these distributions are broadcast and should _all_
+// contain `Distribution.BatchBroadcast` types.
+//
+// All parameter distributions are optional. If a distribution is left
+// unspecified for a given parameter, Meridian will select its default prior
+// distribution.
+//
+// See: `meridian.model.prior_distribution` module.
+// See:
+// https://developers.google.com/meridian/docs/advanced-modeling/default-prior-distributions
+message PriorDistributions {
+  Distribution knot_values = 1;
+  Distribution tau_g_excl_baseline = 2;
+  Distribution beta_m = 3;
+  Distribution beta_rf = 4;
+  Distribution eta_m = 5;
+  Distribution eta_rf = 6;
+  Distribution gamma_c = 7;
+  Distribution xi_c = 8;
+  Distribution alpha_m = 9;
+  Distribution alpha_rf = 10;
+  Distribution ec_m = 11;
+  Distribution ec_rf = 12;
+  Distribution slope_m = 13;
+  Distribution slope_rf = 14;
+  Distribution sigma = 15;
+  Distribution roi_m = 16;
+  Distribution roi_rf = 17;
+  Distribution mroi_m = 30;
+  Distribution mroi_rf = 31;
+  Distribution contribution_m = 32;
+  Distribution contribution_rf = 33;
+  Distribution contribution_om = 34;
+  Distribution contribution_orf = 35;
+  Distribution contribution_n = 36;
+  Distribution beta_om = 18;
+  Distribution beta_orf = 19;
+  Distribution eta_om = 20;
+  Distribution eta_orf = 21;
+  Distribution gamma_n = 22;
+  Distribution xi_n = 23;
+  Distribution alpha_om = 24;
+  Distribution alpha_orf = 25;
+  Distribution ec_om = 26;
+  Distribution ec_orf = 27;
+  Distribution slope_om = 28;
+  Distribution slope_orf = 29;
+}
+
+// A container for user prior distribution parameters in a Meridian model.
+// These distributions are in their mathematical forms when representing
+// user priors in the model spec and are part of the user input in a pre-fitted
+// model.
+//
+// After priors sampling, these distributions are broadcast and should _all_
+// contain `Distribution.BatchBroadcast` types.
+//
+// All parameter distributions are optional. If a distribution is left
+// unspecified for a given parameter, Meridian will select its default prior
+// distribution.
+//
+// See: `meridian.model.prior_distribution` module.
+// See:
+// https://developers.google.com/meridian/docs/advanced-modeling/default-prior-distributions
+message PriorTfpDistributions {
+  TfpDistribution knot_values = 1;
+  TfpDistribution tau_g_excl_baseline = 2;
+  TfpDistribution beta_m = 3;
+  TfpDistribution beta_rf = 4;
+  TfpDistribution eta_m = 5;
+  TfpDistribution eta_rf = 6;
+  TfpDistribution gamma_c = 7;
+  TfpDistribution xi_c = 8;
+  TfpDistribution alpha_m = 9;
+  TfpDistribution alpha_rf = 10;
+  TfpDistribution ec_m = 11;
+  TfpDistribution ec_rf = 12;
+  TfpDistribution slope_m = 13;
+  TfpDistribution slope_rf = 14;
+  TfpDistribution sigma = 15;
+  TfpDistribution roi_m = 16;
+  TfpDistribution roi_rf = 17;
+  TfpDistribution mroi_m = 30;
+  TfpDistribution mroi_rf = 31;
+  TfpDistribution contribution_m = 32;
+  TfpDistribution contribution_rf = 33;
+  TfpDistribution contribution_om = 34;
+  TfpDistribution contribution_orf = 35;
+  TfpDistribution contribution_n = 36;
+  TfpDistribution beta_om = 18;
+  TfpDistribution beta_orf = 19;
+  TfpDistribution eta_om = 20;
+  TfpDistribution eta_orf = 21;
+  TfpDistribution gamma_n = 22;
+  TfpDistribution xi_n = 23;
+  TfpDistribution alpha_om = 24;
+  TfpDistribution alpha_orf = 25;
+  TfpDistribution ec_om = 26;
+  TfpDistribution ec_orf = 27;
+  TfpDistribution slope_om = 28;
+  TfpDistribution slope_orf = 29;
+
+  // Lookup table that contains function names mapped to hashed functions used
+  // by various `tfp.distributions`.
+  map<string, string> function_registry = 37;
+}
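+
+// For illustration (hypothetical values): a `PriorDistributions` message that
+// overrides only the paid media ROI prior, leaving every other parameter to
+// Meridian's defaults:
+//
+//   roi_m {
+//     name: "roi_m"
+//     log_normal { locs: 0.2 scales: 0.9 }
+//   }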
+
+// Possible distribution types for media random effects across geos.
+enum MediaEffectsDistribution {
+  MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED = 0;
+  NORMAL = 1;
+  LOG_NORMAL = 2;
+}
+
+// Possible paid media prior types.
+enum PaidMediaPriorType {
+  PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED = 0;
+  ROI = 1;
+  MROI = 2;
+  COEFFICIENT = 3;
+  CONTRIBUTION = 4;
+}
+
+// Possible non-paid treatments prior types.
+enum NonPaidTreatmentsPriorType {
+  NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED = 0;
+  NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT = 1;
+  NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION = 2;
+}
+
+// Specifies the adstock decay function for each channel.
+message AdstockDecayByChannel {
+  // A map where keys are channel names and values are the adstock decay
+  // function to use for that channel. Allowed values are 'geometric' or
+  // 'binomial'.
+  map<string, string> channel_decays = 1;
+}
+
+// Hyperparameters for the MMM model.
+message Hyperparameters {
+  reserved 5;
+
+  // Specifies the distribution of media random effects across geos.
+  // This hyperparameter is ignored in a national-level model.
+  MediaEffectsDistribution media_effects_dist = 1
+      [features.field_presence = EXPLICIT];
+
+  // Indicates whether to apply the Hill function before the Adstock function.
+  // This hyperparameter does not apply to RF channels.
+  bool hill_before_adstock = 2 [features.field_presence = EXPLICIT];
+
+  // The maximum number of lag periods (>= 0) to include in the Adstock
+  // calculation. If unset, then max lag is interpreted as infinite.
+  uint32 max_lag = 3 [features.field_presence = EXPLICIT];
+
+  // Indicates whether to use a unique residual variance for each geo.
+  // If False, then a single residual variance is used for all geos.
+  bool unique_sigma_for_each_geo = 4 [features.field_presence = EXPLICIT];
+
+  // Prior type for the media coefficients. If `paid_media_prior_type` is
+  // `'coefficient'`, then the model uses `beta_[m|rf]` distributions in the
+  // priors. If `paid_media_prior_type` is `'roi'` or `'mroi'`, then the
+  // `roi_[m|rf]` are used.
+  //
+  // Deprecated. Use `media_prior_type` and `rf_prior_type` instead.
+  PaidMediaPriorType paid_media_prior_type = 13
+      [features.field_presence = EXPLICIT, deprecated = true];
+
+  // Prior type for the (paid, non-RF) media coefficients. If
+  // `media_prior_type` is `'coefficient'`, then the model uses the `beta_m`
+  // distribution in the priors. If `media_prior_type` is `'roi'` or `'mroi'`,
+  // then `roi_m` or `mroi_m` is used, respectively. If `media_prior_type` is
+  // `'contribution'`, then `contribution_m` is used.
+  PaidMediaPriorType media_prior_type = 17 [features.field_presence = EXPLICIT];
+
+  // Prior type for the (paid) RF coefficients. If `rf_prior_type` is
+  // `'coefficient'`, then the model uses the `beta_rf` distribution in the
+  // priors. If `rf_prior_type` is `'roi'` or `'mroi'`, then `roi_rf` or
+  // `mroi_rf` is used, respectively. If `rf_prior_type` is `'contribution'`,
+  // then `contribution_rf` is used.
+  PaidMediaPriorType rf_prior_type = 18 [features.field_presence = EXPLICIT];
+
+  // Prior type for the organic media coefficients. If
+  // `organic_media_prior_type` is `'coefficient'`, then the model uses the
+  // `beta_om` distribution in the priors. If `organic_media_prior_type` is
+  // `'contribution'`, then `contribution_om` is used.
+  NonPaidTreatmentsPriorType organic_media_prior_type = 19
+      [features.field_presence = EXPLICIT];
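+
+  // For example (illustrative): `media_prior_type: ROI` parameterizes paid
+  // media through the `roi_m` prior, whereas `media_prior_type: COEFFICIENT`
+  // places the prior on `beta_m` directly.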
+
+  // Prior type for the organic RF coefficients. If `organic_rf_prior_type` is
+  // `'coefficient'`, then the model uses the `beta_orf` distribution in the
+  // priors. If `organic_rf_prior_type` is `'contribution'`, then
+  // `contribution_orf` is used.
+  NonPaidTreatmentsPriorType organic_rf_prior_type = 20
+      [features.field_presence = EXPLICIT];
+
+  // Prior type for the non-media treatments coefficients. If
+  // `non_media_treatments_prior_type` is `'coefficient'`, then the model uses
+  // the `gamma_n` distribution in the priors. If
+  // `non_media_treatments_prior_type` is `'contribution'`, then
+  // `contribution_n` is used.
+  NonPaidTreatmentsPriorType non_media_treatments_prior_type = 21
+      [features.field_presence = EXPLICIT];
+
+  // A boolean tensor in the shape `(n_media_times, n_media_channels)`.
+  // This indicates the subset of `time` coordinates in the model for media ROI
+  // calibration. If unset, all time coordinates are used for media ROI
+  // calibration.
+  tensorflow.TensorProto roi_calibration_period = 9;
+
+  // A boolean tensor in the shape `(n_media_times, n_rf_channels)`.
+  // This indicates the subset of `time` coordinates in the model for reach and
+  // frequency ROI calibration. If unset, all time coordinates are used for R&F
+  // ROI calibration.
+  tensorflow.TensorProto rf_roi_calibration_period = 10;
+
+  // A (single-value) integer or a list of integers, indicating the knots used
+  // to estimate time effects.
+  // If provided as a list of integers, its indices correspond to the indices
+  // of the time coordinates in the model.
+  // If provided as a single integer, then that many knots are used, with
+  // locations equally spaced across time periods.
+  // If unset, then the number of knots used is equal to the number of time
+  // periods in the case of a geo model (i.e. each time period has its own
+  // regression coefficient). If unset in a national model, then the model
+  // uses `1` as the number of knots.
+  repeated int32 knots = 6;
+
+  // A boolean indicating whether to use the Automatic Knot Selection
+  // algorithm to select the optimal number of knots for running the model,
+  // instead of the default of 1 for national models and n_times for
+  // non-national models. If this is set to true and the `knots` argument is
+  // also provided, then an error will be raised when deserializing back to
+  // `ModelSpec`. Default: `False`.
+  bool enable_aks = 22 [features.field_presence = EXPLICIT];
+
+  // The baseline geo is treated as the reference geo in the encoding of geos.
+  // If neither option is set, the model will use the geo with the largest
+  // population as the baseline geo.
+  oneof baseline_geo_oneof {
+    // Deprecated. Use `baseline_geo_int` instead.
+    double baseline_geo_int_deprecated = 7 [deprecated = true];
+
+    // Integer representation of the baseline geo.
+    int32 baseline_geo_int = 15;
+    string baseline_geo_string = 8;
+  }
+
+  oneof holdout_spec {
+    // A boolean tensor in the shape `(n_geos, n_times)` for a geo-level model
+    // or `(n_times,)` for a national model.
+    // This indicates which observations are part of the holdout sample, which
+    // are excluded from the training sample. For more details on the holdout
+    // sample, see: `meridian.model.spec.ModelSpec`
+    tensorflow.TensorProto holdout_id = 11;
+
+    // The ratio of holdout data to use for the goodness-of-fit check. This is
+    // taken as input and applied by the training module right before the data
+    // is fed into Meridian.
+    // As of Q1 2025, the holdout is applied only by date.
+    double holdout_ratio = 16;
+  }
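+
+  // For example (illustrative): `holdout_ratio: 0.2` reserves 20% of the data
+  // for the goodness-of-fit check, while `holdout_id` instead pins the exact
+  // observations to hold out.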
+
+  // A boolean tensor in the shape `(n_controls,)`.
+  // This indicates the control variables for which the control value will be
+  // scaled by population.
+  // If unset, no control variables are scaled by population.
+  tensorflow.TensorProto control_population_scaling_id = 12;
+
+  // A boolean tensor in the shape `(n_non_media_channels,)`.
+  // This indicates the non-media treatments channels for which the value will
+  // be scaled by population.
+  // If unset, no non-media treatments channels are scaled by population.
+  tensorflow.TensorProto non_media_population_scaling_id = 14;
+
+  // Specifies the adstock decay function for each media, RF, organic media,
+  // and organic RF channel. Defaults to 'geometric'.
+  oneof adstock_decay_spec {
+    // The global adstock decay function to use for all channels. Allowed
+    // values are 'geometric' or 'binomial'.
+    string global_adstock_decay = 23;
+
+    // Channel-specific adstock decay functions. Defaults to 'geometric' for
+    // channels not specified in the map.
+    AdstockDecayByChannel adstock_decay_by_channel = 24;
+  }
+}
+
+// A named tensor parameter.
+message Parameter {
+  string name = 1 [features.field_presence = EXPLICIT];
+
+  tensorflow.TensorProto tensor = 2;
+}
+
+// An `InferenceData` container holds none, only prior, or both prior and
+// posterior sampled parameters, along with their sampling states and trace
+// from fitting the model.
+// https://python.arviz.org/en/stable/api/generated/arviz.InferenceData.html
+//
+// All fields inside this container are `xarray.Dataset`s that are
+// byte-serialized in NetCDF format.
+// See: https://docs.xarray.dev/en/stable/user-guide/io.html
+message InferenceData {
+  reserved 1;
+
+  // Sampled prior parameters as an `xarray.Dataset` serialized in NetCDF4
+  // format.
+  bytes prior = 2 [features.field_presence = EXPLICIT];
+
+  // Sampled posterior parameters as an `xarray.Dataset` serialized in NetCDF4
+  // format.
+  bytes posterior = 3 [features.field_presence = EXPLICIT];
+
+  // Contains "sample_stats", "trace", and other auxiliary data that are useful
+  // for debugging.
+  // "sample_stats" and "trace" specifically are available when `posterior` is
+  // available.
+  map<string, bytes> auxiliary_data = 4;
+}
+
+// The trace of MCMC sampling.
+message McmcSamplingTrace {
+  uint32 num_chains = 1 [features.field_presence = EXPLICIT];
+
+  uint32 num_draws = 2 [features.field_presence = EXPLICIT];
+
+  tensorflow.TensorProto step_size = 3;
+  tensorflow.TensorProto tune = 4;
+  tensorflow.TensorProto target_log_prob = 5;
+  tensorflow.TensorProto diverging = 6;
+  tensorflow.TensorProto accept_ratio = 7;
+  tensorflow.TensorProto n_steps = 8;
+  tensorflow.TensorProto is_accepted = 9;
+}
+
+// Diagnostic of MCMC sampling, computed as an r-hat value for each parameter.
+message RHatDiagnostic {
+  // The r-hat values of model parameters.
+  //
+  // Current list of parameters: see `InferenceData` above.
+  repeated Parameter parameter_r_hats = 1;
+}
+
+message ModelConvergence {
+  McmcSamplingTrace mcmc_sampling_trace = 1;
+
+  // Convergence heuristic check for the MCMC sampling.
+  bool convergence = 2 [features.field_presence = EXPLICIT];
+
+  RHatDiagnostic r_hat_diagnostic = 3;
+}
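+
+// A sketch of reading the serialized datasets back in Python. This assumes
+// `xarray` with a NetCDF engine that accepts file-like objects (`h5netcdf`
+// here is an assumption, not a requirement of this schema):
+//
+//   import io
+//   import xarray as xr
+//
+//   posterior = xr.open_dataset(
+//       io.BytesIO(inference_data.posterior), engine="h5netcdf")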
+
+// Meridian model schema.
+message MeridianModel {
+  reserved 8, 9, 12, 13;
+
+  // The unique identifier of this model.
+  string model_id = 1 [features.field_presence = EXPLICIT];
+
+  // The semantic version of the Meridian library used to generate this model.
+  string model_version = 2 [features.field_presence = EXPLICIT];
+
+  Hyperparameters hyperparameters = 3;
+  PriorDistributions prior_distributions = 11 [deprecated = true];
+  PriorTfpDistributions prior_tfp_distributions = 18;
+
+  // Tensor properties: scaled input data.
+  // These tensors are derived from marketing data in the model's input
+  // (see: `marketing_data.proto`) after they are transformed based on the
+  // model's spec. A Meridian model can be reconstructed from the marketing
+  // (input) data, so these scaled tensors are not technically required for
+  // deserialization.
+  tensorflow.TensorProto media_scaled = 4;
+  tensorflow.TensorProto reach_scaled = 5;
+  tensorflow.TensorProto controls_scaled = 6;
+  tensorflow.TensorProto kpi_scaled = 7;
+  tensorflow.TensorProto organic_media_scaled = 15;
+  tensorflow.TensorProto organic_reach_scaled = 16;
+  tensorflow.TensorProto non_media_treatments_scaled = 17;
+
+  // Inference data contains sampled priors and posteriors.
+  InferenceData inference_data = 14;
+
+  // Contains the information about model convergence status.
+  ModelConvergence convergence_info = 10;
+}
+
+// Represents a TensorFlow statistical distribution spec that is used in user
+// priors in a Meridian model.
+// All fields are required unless otherwise specified.
+// See: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions
+message TfpDistribution {
+  // A `tfp.distributions` class name.
+  // e.g. "Normal", "TransformedDistribution", etc.
+  string distribution_type = 1;
+
+  // Parameters for the specific distribution type.
+  map<string, TfpParameterValue> parameters = 2;
+}
+
+// Represents a constructor parameter for a `tfp.distributions` class.
+message TfpParameterValue {
+  // For parameter values that are lists or tuples.
+  message List {
+    repeated TfpParameterValue values = 1;
+  }
+
+  // For parameter values that are dicts.
+  message Dict {
+    map<string, TfpParameterValue> value_map = 1;
+  }
+
+  // For parameter values that are functions.
+  message FunctionParam {
+    oneof param {
+      // A key that maps to a custom function in the user-provided function
+      // registry. The registry allows the model to be serialized without
+      // including the function's code, enabling a more secure deserialization
+      // process.
+      string function_key = 1;
+
+      // Whether the Distribution uses the default function implementation.
+      bool uses_default = 2;
+    }
+  }
+
+  oneof value_type {
+    // Primitive distribution parameter value types.
+    float scalar_value = 1;
+    int32 int_value = 2;
+    bool bool_value = 3;
+    string string_value = 4;
+    bool none_value = 5;
+
+    // For a nested distribution parameter (e.g. for `TransformedDistribution`)
+    TfpDistribution distribution_value = 6;
+
+    // For a nested bijector parameter (e.g. for `TransformedDistribution`)
+    TfpBijector bijector_value = 7;
+
+    // For a parameter that takes a list of parameters.
+    List list_value = 8;
+
+    // For a parameter that takes a dict.
+    Dict dict_value = 9;
+
+    // For a parameter that takes a Tensor.
+    tensorflow.TensorProto tensor_value = 10;
+
+    // Whether the distribution should be fully reparameterized.
+    // See:
+    // https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/ReparameterizationType
+    bool fully_reparameterized = 11;
+
+    // For a parameter that takes a function.
+    FunctionParam function_param = 12;
+  }
+}
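+
+// For illustration (hypothetical values): a
+// `tfp.distributions.Normal(loc=0., scale=1.)` captured in this spec form:
+//
+//   distribution_type: "Normal"
+//   parameters {
+//     key: "loc"
+//     value { scalar_value: 0.0 }
+//   }
+//   parameters {
+//     key: "scale"
+//     value { scalar_value: 1.0 }
+//   }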
+
+// Represents a TensorFlow bijector spec that is used within distribution
+// parameters in a Meridian model.
+// All fields are required unless otherwise specified.
+// See: https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors
+message TfpBijector {
+  // A `tfp.bijectors` class name.
+  // e.g. "Shift", "Scale", etc.
+  string bijector_type = 1;
+
+  // Parameters for the specific bijector type.
+  map<string, TfpParameterValue> parameters = 2;
+}
diff --git a/proto/mmm/v1/model/mmm_kernel.proto b/proto/mmm/v1/model/mmm_kernel.proto
new file mode 100644
index 000000000..5eda371d8
--- /dev/null
+++ b/proto/mmm/v1/model/mmm_kernel.proto
@@ -0,0 +1,35 @@
+// Copyright 2025 The Meridian Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+edition = "2023";
+
+package mmm.v1.model;
+
+import "google/protobuf/any.proto";
+import "mmm/v1/marketing/marketing_data.proto";
+
+option java_multiple_files = true;
+option java_package = "com.google.protos.mmm.v1.model";
+
+message MmmKernel {
+  // The marketing data that is used to train the marketing mix model and is
+  // later analyzed by the model.
+  marketing.MarketingData marketing_data = 1;
+
+  // The details about the model implementation.
+  //
+  // This should contain a trained marketing mix model along with
+  // model-specific information such as model convergence and model-usage
+  // flags.
+  google.protobuf.Any model = 2;
+}
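+
+// A sketch of packing a model into the `Any` field with the generated Python
+// bindings (`meridian_model_pb2` is assumed to be the generated module name;
+// `mmm_kernel_pb2` appears elsewhere in this change):
+//
+//   from mmm.v1.model import mmm_kernel_pb2
+//   from mmm.v1.model.meridian import meridian_model_pb2
+//
+//   kernel = mmm_kernel_pb2.MmmKernel()
+//   kernel.model.Pack(meridian_model_pb2.MeridianModel(model_id="example"))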
diff --git a/proto/pyproject.toml b/proto/pyproject.toml
new file mode 100644
index 000000000..aeca39528
--- /dev/null
+++ b/proto/pyproject.toml
@@ -0,0 +1,66 @@
+[project]
+name = "mmm-proto-schema"
+description = """\
+    This is a private submodule for MMM model representation in protobuf form.\
+    """
+readme = "README.md"
+requires-python = ">=3.10"
+license = {file = "LICENSE"}
+authors = [
+  {name = "The Meridian Authors", email="no-reply@google.com"},
+]
+classifiers = [  # List of https://pypi.org/classifiers/
+    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3 :: Only",
+    "Topic :: Other/Nonlisted Topic",
+    "Topic :: Scientific/Engineering :: Information Analysis",
+]
+
+keywords = [
+    "mmm",
+    "protobuf",
+]
+
+dynamic = ["version"]
+
+dependencies = [
+    "protobuf",
+    "googleapis-common-protos",
+    "tensorflow",
+    "google-api-python-client",
+]
+
+[project.urls]
+homepage = "https://github.com/google/meridian"
+repository = "https://github.com/google/meridian"
+changelog = "https://github.com/google/meridian/blob/main/CHANGELOG.md"
+documentation = "https://developers.google.com/meridian"
+
+[build-system]
+# The build system specifies which backend is used to build/install the
+# project (flit, poetry, setuptools, ...). All backends are supported by
+# `pip install`.
+requires = [
+    "setuptools >= 61.0.0",
+    "tomli",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.unified_schema_builder]
+proto_root=""
+github_includes=[
+    "https://github.com/googleapis/googleapis.git",
+    "https://github.com/tensorflow/tensorflow.git",
+]
+
+[tool.setuptools.packages.find]
+include = [
+    "mmm*",
+]
+exclude = ["*test"]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.dynamic]
+version = {attr = "__version__"}
diff --git a/proto/setup.py b/proto/setup.py
new file mode 100644
index 000000000..f6d3222ad
--- /dev/null
+++ b/proto/setup.py
@@ -0,0 +1,124 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The setup.py file for MMM Unified Schema."""
+
+import logging
+import pathlib
+import subprocess
+import sys
+import tempfile
+
+try:
+  import tomllib
+except ModuleNotFoundError:  # `tomllib` is only available on Python >= 3.11.
+  tomllib = None
+
+import setuptools
+from setuptools.command import build
+import tomli
+
+
+def _toml_load(f):
+  if sys.version_info[:2] >= (3, 11):
+    return tomllib.load(f)
+  else:
+    # For Python < 3.11, fall back to the `tomli` backport.
+    return tomli.load(f)
+
+
+class ProtoBuild(setuptools.Command):
+  """Custom command to build proto files."""
+
+  def initialize_options(self):
+    with open("pyproject.toml", "rb") as f:
+      cfg = _toml_load(f).get("tool", {}).get("unified_schema_builder")
+    self._root = pathlib.Path(*cfg.get("proto_root").split("/"))
+    self._deps = {
+        url.split("/")[-1].split(".")[0]: url
+        for url in cfg.get("github_includes")
+    }
+    self._srcs = list(self._root.rglob("*.proto"))
+
+  def finalize_options(self):
+    pass
+
+  def _check_protoc_version(self):
+    out = subprocess.run(
+        "protoc --version".split(), check=True, text=True, capture_output=True
+    ).stdout
+    out = out.strip() if out else ""
+    if out.startswith("libprotoc"):
+      return int(out.split()[1].split(".")[0])
+    return 0
+
+  def _run_cmds(self, commands):
+    for c in commands:
+      cmd_str = " ".join(c)
+      logging.info("Running command %s", cmd_str)
+      try:
+        subprocess.run(c, capture_output=True, text=True, check=True)
+      except subprocess.CalledProcessError as e:
+        logging.error(
+            "Skipping Unified Schema compilation since command %s failed:\n%s",
+            cmd_str,
+            e.stderr.strip(),
+        )
+        return e.returncode
+    return 0
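+
+  # For illustration (hypothetical paths): with `proto_root` "" and a cloned
+  # include under /tmp/deps/googleapis, `_compile_proto_in_place` below issues
+  # commands of the form:
+  #
+  #   protoc -I. -I/tmp/deps/googleapis --python_out=. mmm/v1/mmm.proto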
+  def _compile_proto_in_place(self, includes):
+    i = [f"-I{include_path}" for include_path in includes]
+    srcs = list(self._srcs)
+    commands = [
+        ["protoc"] + i + f"--python_out=. {src}".split() for src in srcs
+    ]
+    return self._run_cmds(commands)
+
+  def _pull_deps(self, root):
+    cmds = []
+    for folder, url in self._deps.items():
+      target_path = root / folder
+      target_path.mkdir(parents=True, exist_ok=True)
+      cmds.append(f"git clone --quiet {url} {target_path}".split())
+    return self._run_cmds(cmds)
+
+  def run(self):
+    protoc_major_version = self._check_protoc_version()
+    if protoc_major_version < 27:
+      logging.error(
+          "Skipping Unified Schema compilation since the existing compiler"
+          " version is %s, which is lower than 27",
+          protoc_major_version,
+      )
+      return
+
+    with tempfile.TemporaryDirectory() as t:
+      temp_root = pathlib.Path(t)
+      if self._pull_deps(temp_root):
+        return
+      includes = [self._root] + [temp_root / path for path in self._deps.keys()]
+      if self._compile_proto_in_place(includes):
+        return
+
+
+class CustomBuild(build.build):
+  sub_commands = [
+      ("compile_unified_schema_proto", None)
+  ] + build.build.sub_commands
+
+
+if __name__ == "__main__":
+  setuptools.setup(
+      cmdclass={
+          "build": CustomBuild,
+          "compile_unified_schema_proto": ProtoBuild,
+      }
+  )
diff --git a/pyproject.toml b/pyproject.toml
index 019e67b00..e3cbbd089 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,13 @@ and-cuda = [
 # Installed through `pip install -e .[mlflow]`
 mlflow = ["mlflow"]
 
+# MMM schema, only works with `pip install .[schema]`
+schema = [
+    # TODO: publish a schema package to pypi
+    "mmm-proto-schema @ file:./proto",
+    "semver",
+]
+
 # JAX backend dependencies.
 # Installed through `pip install -e .[jax]`
 jax = [
@@ -103,7 +110,11 @@ build-backend = "setuptools.build_meta"
 include-package-data = true
 
 [tool.setuptools.packages.find]
-include = ["meridian*"]
+include = [
+    "meridian*",
+    "scenarioplanner*",
+    "schema*",
+]
 exclude = ["*test"]
 
 [tool.setuptools.dynamic]
diff --git a/scenarioplanner/__init__.py b/scenarioplanner/__init__.py
new file mode 100644
index 000000000..3720c207b
--- /dev/null
+++ b/scenarioplanner/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generates Meridian Scenario Planner Dashboards in Looker Studio.
+
+This package provides tools to create and manage Meridian dashboards. It helps
+transform data from the MMM (Marketing Mix Modeling) schema into a custom
+Looker Studio dashboard, which can be shared via a URL.
+
+The typical workflow is:
+
+  1. Analyze MMM data into the appropriate schema.
+  2. Generate UI-specific proto messages from this data using
+     `mmm_ui_proto_generator`.
+  3. Build a Looker Studio URL that embeds this UI proto data using
+     `linkingapi`.
+
+Key functionalities include:
+
+  - `linkingapi`: Builds Looker Studio report URLs with embedded data sources.
+    This allows for the creation of pre-configured reports.
+  - `mmm_ui_proto_generator`: Generates a `Mmm` proto message for the Meridian
+    Scenario Planner UI. It takes structured MMM data and transforms it into the
+    specific proto format that the dashboard frontend expects.
+ - `converters`: Provides utilities to convert and transform analyzed model + data into a data format that Looker Studio expects. +""" + +from scenarioplanner import converters +from scenarioplanner import linkingapi +from scenarioplanner import mmm_ui_proto_generator diff --git a/scenarioplanner/converters/__init__.py b/scenarioplanner/converters/__init__.py new file mode 100644 index 000000000..15b58432f --- /dev/null +++ b/scenarioplanner/converters/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides tools for converting and wrapping MMM schema data. + +This package contains modules to transform Marketing Mix Modeling (MMM) protocol +buffer data into other formats and provides high-level wrappers for easier data +manipulation, analysis, and reporting. +""" + +from scenarioplanner.converters import dataframe +from scenarioplanner.converters import mmm +from scenarioplanner.converters import mmm_converter +from scenarioplanner.converters import sheets diff --git a/scenarioplanner/converters/dataframe/__init__.py b/scenarioplanner/converters/dataframe/__init__.py new file mode 100644 index 000000000..a6a3673da --- /dev/null +++ b/scenarioplanner/converters/dataframe/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Converters for `Mmm` protos to flat dataframes. + +This package provides a set of tools for transforming data from `Mmm` +protos into flat dataframes. This conversion makes the data easier to analyze, +visualize, and use in other data processing pipelines. +""" + +from scenarioplanner.converters.dataframe import budget_opt_converters +from scenarioplanner.converters.dataframe import common +from scenarioplanner.converters.dataframe import constants +from scenarioplanner.converters.dataframe import converter +from scenarioplanner.converters.dataframe import dataframe_model_converter +from scenarioplanner.converters.dataframe import marketing_analyses_converters +from scenarioplanner.converters.dataframe import rf_opt_converters diff --git a/scenarioplanner/converters/dataframe/budget_opt_converters.py b/scenarioplanner/converters/dataframe/budget_opt_converters.py new file mode 100644 index 000000000..239e84273 --- /dev/null +++ b/scenarioplanner/converters/dataframe/budget_opt_converters.py @@ -0,0 +1,383 @@ +# Copyright 2025 The Meridian Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Budget optimization converters. + +This module defines various classes that convert `BudgetOptimizationResult`s +into flat dataframes. +""" + +import abc +from collections.abc import Iterator, Sequence + +from meridian import constants as c +from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb +from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb +from scenarioplanner.converters import mmm +from scenarioplanner.converters.dataframe import common +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import converter +import pandas as pd + + +__all__ = [ + "NamedOptimizationGridConverter", + "BudgetOptimizationSpecsConverter", + "BudgetOptimizationResultsConverter", + "BudgetOptimizationResponseCurvesConverter", +] + + +class _BudgetOptimizationConverter(converter.Converter, abc.ABC): + """An abstract class for dealing with `BudgetOptimizationResult`s.""" + + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + results = self._mmm.budget_optimization_results + if not results: + return + + # Validate group IDs. + group_ids = [result.group_id for result in results if result.group_id] + if len(set(group_ids)) != len(group_ids): + raise ValueError( + "Specified group_id must be unique or unset among the given group of" + " results." + ) + + yield from self._handle_budget_optimization_results( + self._mmm.budget_optimization_results + ) + + def _handle_budget_optimization_results( + self, results: Sequence[mmm.BudgetOptimizationResult] + ) -> Iterator[tuple[str, pd.DataFrame]]: + raise NotImplementedError() + + +class NamedOptimizationGridConverter(_BudgetOptimizationConverter): + """Outputs named tables for budget optimization grids. + + When called, this converter returns a data frame with the columns: + + * "Group ID" + A UUID generated for each named incremental outcome grid. + * "Channel" + * "Spend" + * "Incremental Outcome" + For each named budget optimization result in the MMM output proto. + """ + + def _handle_budget_optimization_results( + self, results: Sequence[mmm.BudgetOptimizationResult] + ) -> Iterator[tuple[str, pd.DataFrame]]: + for budget_opt_result in results: + # There should be one unique ID for each result. + group_id = ( + str(budget_opt_result.group_id) if budget_opt_result.group_id else "" + ) + grid = budget_opt_result.incremental_outcome_grid + + # Each grid yields its own data frame table. 
+ optimization_grid_data = [] + for channel, cells in grid.channel_spend_grids.items(): + for spend, incremental_outcome in cells: + optimization_grid_data.append([ + group_id, + channel, + spend, + incremental_outcome, + ]) + + yield ( + common.create_grid_sheet_name( + dc.OPTIMIZATION_GRID_NAME_PREFIX, grid.name + ), + pd.DataFrame( + optimization_grid_data, + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_GRID_SPEND_COLUMN, + dc.OPTIMIZATION_GRID_INCREMENTAL_OUTCOME_COLUMN, + ], + ), + ) + + +class BudgetOptimizationSpecsConverter(_BudgetOptimizationConverter): + """Outputs a table of budget optimization specs. + + When called, this converter returns a data frame with the columns: + + * "Group ID" + A UUID generated for an incremental outcome grid present in the output. + * "Date Interval Start" + * "Date Interval End" + * "Analysis Period" + * "Objective" + * "Scenario Type" + * "Initial Channel Spend" + * "Target Metric Constraint" + None if scenario type is "Fixed" + * "Target Metric Value" + None if scenario type is "Fixed" + * "Channel" + * "Channel Spend Min" + * "Channel Spend Max" + """ + + def _handle_budget_optimization_results( + self, results: Sequence[mmm.BudgetOptimizationResult] + ) -> Iterator[tuple[str, pd.DataFrame]]: + spec_data = [] + for budget_opt_result in results: + # There should be one unique ID for each result. + group_id = ( + str(budget_opt_result.group_id) if budget_opt_result.group_id else "" + ) + spec = budget_opt_result.spec + + objective = common.map_target_metric_str(spec.objective) + # These are the start and end dates for the requested budget optimization + # in this spec. + date_interval_start, date_interval_end = ( + d.strftime(c.DATE_FORMAT) for d in spec.date_interval.date_interval + ) + budget_date_interval = (date_interval_start, date_interval_end) + + # aka historical spend from marketing data in the model kernel + initial_channel_spends = self._mmm.marketing_data.all_channel_spends( + budget_date_interval + ) + + scenario = ( + dc.OPTIMIZATION_SPEC_SCENARIO_FIXED + if spec.is_fixed_scenario + else dc.OPTIMIZATION_SPEC_SCENARIO_FLEXIBLE + ) + + if spec.is_fixed_scenario: + target_metric_constraint = None + target_metric_value = None + else: + flexible_scenario = ( + spec.budget_optimization_spec_proto.flexible_budget_scenario + ) + # Meridian flexible budget spec only has one target metric constraint. + target_metric_constraint_pb = ( + flexible_scenario.target_metric_constraints[0] + ) + target_metric_constraint = common.map_target_metric_str( + target_metric_constraint_pb.target_metric + ) + target_metric_value = target_metric_constraint_pb.target_value + + # When the constraint of a channel is not specified, that channel will + # have a constraint of `[0, max_budget]` which is equivalent to no + # constraint. + # + # Here, `max_budget` is the total budget for a fixed scenario spec, or the + # max budget upper bound for a flexible scenario spec. + # + # NOTE: This assumption must be in line with what the budget optimization + # processor does with an empty channel constraints list. + channel_constraints = spec.channel_constraints + if not channel_constraints: + # Implicit channel constraints; synthesize them first before proceeding. 
+ channel_constraints = [ + budget_pb.ChannelConstraint( + channel_name=channel_name, + budget_constraint=constraints_pb.BudgetConstraint( + min_budget=0.0, + max_budget=spec.max_budget, + ), + ) + for channel_name in self._mmm.marketing_data.media_channels + ] + + for channel_constraint in channel_constraints: + spec_data.append([ + group_id, + date_interval_start, + date_interval_end, + spec.date_interval_tag, + objective, + scenario, + initial_channel_spends.get(channel_constraint.channel_name, 0.0), + target_metric_constraint, + target_metric_value, + channel_constraint.channel_name, + channel_constraint.budget_constraint.min_budget, + channel_constraint.budget_constraint.max_budget, + ]) + + yield ( + dc.OPTIMIZATION_SPECS, + pd.DataFrame( + spec_data, + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_SPEC_DATE_INTERVAL_START_COLUMN, + dc.OPTIMIZATION_SPEC_DATE_INTERVAL_END_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.OPTIMIZATION_SPEC_OBJECTIVE_COLUMN, + dc.OPTIMIZATION_SPEC_SCENARIO_TYPE_COLUMN, + dc.OPTIMIZATION_SPEC_INITIAL_CHANNEL_SPEND_COLUMN, + dc.OPTIMIZATION_SPEC_TARGET_METRIC_CONSTRAINT_COLUMN, + dc.OPTIMIZATION_SPEC_TARGET_METRIC_VALUE_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_SPEC_CHANNEL_SPEND_MIN_COLUMN, + dc.OPTIMIZATION_SPEC_CHANNEL_SPEND_MAX_COLUMN, + ], + ), + ) + + +class BudgetOptimizationResultsConverter(_BudgetOptimizationConverter): + """Outputs a table of budget optimization results objectives. + + When called, this converter returns a data frame with the columns: + + * "Group ID" + A UUID generated for a budget optimization result present in the output. + * "Channel" + * "Is Revenue KPI" + Whether the KPI is revenue or not. + * "Optimal Spend" + * "Optimal Spend Share" + * "Optimal Impression Effectiveness" + * "Optimal ROI" + * "Optimal mROI" + * "Optimal CPC" + """ + + def _handle_budget_optimization_results( + self, results: Sequence[mmm.BudgetOptimizationResult] + ) -> Iterator[tuple[str, pd.DataFrame]]: + data = [] + + for budget_opt_result in results: + group_id = ( + str(budget_opt_result.group_id) if budget_opt_result.group_id else "" + ) + marketing_analysis = budget_opt_result.optimized_marketing_analysis + + media_channel_analyses = marketing_analysis.channel_mapped_media_analyses + for channel, media_analysis in media_channel_analyses.items(): + # Skip "All Channels" pseudo-channel. 
+ if channel == c.ALL_CHANNELS: + continue + + spend = media_analysis.spend_info_pb.spend + spend_share = media_analysis.spend_info_pb.spend_share + + revenue_outcome = media_analysis.maybe_revenue_outcome + nonrevenue_outcome = media_analysis.maybe_non_revenue_outcome + + # pylint: disable=cell-var-from-loop + def _append_outcome_data( + outcome: mmm.Outcome | None, + is_revenue_kpi: bool, + ) -> None: + if outcome is None: + return + effectiveness = outcome.effectiveness_pb.value.value + roi = outcome.roi_pb.value + mroi = outcome.marginal_roi_pb.value + cpc = outcome.cost_per_contribution_pb.value + data.append([ + group_id, + channel, + is_revenue_kpi, + spend, + spend_share, + effectiveness, + roi, + mroi, + cpc, + ]) + + _append_outcome_data(revenue_outcome, True) + _append_outcome_data(nonrevenue_outcome, False) + # pylint: enable=cell-var-from-loop + + yield ( + dc.OPTIMIZATION_RESULTS, + pd.DataFrame( + data, + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN, + dc.OPTIMIZATION_RESULT_SPEND_COLUMN, + dc.OPTIMIZATION_RESULT_SPEND_SHARE_COLUMN, + dc.OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN, + dc.OPTIMIZATION_RESULT_ROI_COLUMN, + dc.OPTIMIZATION_RESULT_MROI_COLUMN, + dc.OPTIMIZATION_RESULT_CPC_COLUMN, + ], + ), + ) + + +class BudgetOptimizationResponseCurvesConverter(_BudgetOptimizationConverter): + """Outputs a table of budget optimization response curves. + + When called, this converter returns a data frame with the columns: + + * "Group ID" + A UUID generated for a budget optimization result present in the output. + * "Channel" + * "Spend" + * "Incremental Outcome" + """ + + def _handle_budget_optimization_results( + self, results: Sequence[mmm.BudgetOptimizationResult] + ) -> Iterator[tuple[str, pd.DataFrame]]: + response_curve_data = [] + for budget_opt_result in results: + group_id = ( + str(budget_opt_result.group_id) if budget_opt_result.group_id else "" + ) + curves = budget_opt_result.response_curves + for curve in curves: + for spend, incremental_outcome in curve.response_points: + response_curve_data.append([ + group_id, + curve.channel_name, + spend, + incremental_outcome, + ]) + + yield ( + dc.OPTIMIZATION_RESPONSE_CURVES, + pd.DataFrame( + response_curve_data, + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_GRID_SPEND_COLUMN, + dc.OPTIMIZATION_GRID_INCREMENTAL_OUTCOME_COLUMN, + ], + ), + ) + + +CONVERTERS = [ + NamedOptimizationGridConverter, + BudgetOptimizationSpecsConverter, + BudgetOptimizationResultsConverter, + BudgetOptimizationResponseCurvesConverter, +] diff --git a/scenarioplanner/converters/dataframe/budget_opt_converters_test.py b/scenarioplanner/converters/dataframe/budget_opt_converters_test.py new file mode 100644 index 000000000..069619e98 --- /dev/null +++ b/scenarioplanner/converters/dataframe/budget_opt_converters_test.py @@ -0,0 +1,485 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb +from mmm.v1.marketing.optimization import marketing_optimization_pb2 as optimization_pb +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from scenarioplanner.converters import mmm +from scenarioplanner.converters import test_data as td +from scenarioplanner.converters.dataframe import budget_opt_converters as converters +from scenarioplanner.converters.dataframe import constants as dc +import pandas as pd + + +mock = absltest.mock + + +_DEFAULT_MMM_PROTO = mmm_pb.Mmm( + mmm_kernel=kernel_pb.MmmKernel( + marketing_data=td.MARKETING_DATA, + ), + marketing_optimization=optimization_pb.MarketingOptimization( + budget_optimization=budget_pb.BudgetOptimization( + results=[ + td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES, + td.BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV, + ] + ), + ), +) + +_GID1 = td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES.group_id +_GID2 = td.BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV.group_id + + +class NamedOptimizationGridConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.NamedOptimizationGridConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.NamedOptimizationGridConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + dataframes = list(conv()) + + self.assertLen(dataframes, 2) + (foo_grid_name, foo_grid_df), (bar_grid_name, bar_grid_df) = dataframes + + expected_foo_grid_name = "_".join( + [dc.OPTIMIZATION_GRID_NAME_PREFIX, "incremental_outcome_grid_foo"] + ) + expected_bar_grid_name = "_".join( + [dc.OPTIMIZATION_GRID_NAME_PREFIX, "incremental_outcome_grid_bar"] + ) + + self.assertEqual(foo_grid_name, expected_foo_grid_name) + self.assertEqual(bar_grid_name, expected_bar_grid_name) + + expected_columns = [ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_GRID_SPEND_COLUMN, + dc.OPTIMIZATION_GRID_INCREMENTAL_OUTCOME_COLUMN, + ] + pd.testing.assert_frame_equal( + foo_grid_df, + pd.DataFrame( + [ + [ + _GID1, + "Channel 1", + 10000.0, + 100.0, + ], + [ + _GID1, + "Channel 1", + 20000.0, + 200.0, + ], + [ + _GID1, + "Channel 2", + 10000.0, + 100.0, + ], + [ + _GID1, + "Channel 2", + 20000.0, + 200.0, + ], + ], + columns=expected_columns, + ), + ) + pd.testing.assert_frame_equal( + bar_grid_df, + pd.DataFrame( + [ + [ + _GID2, + "Channel 1", + 1000.0, + 10.0, + ], + [ + _GID2, + "Channel 1", + 2000.0, + 20.0, + ], + [ + _GID2, + "Channel 2", + 1000.0, + 10.0, + ], + [ + _GID2, + "Channel 2", + 2000.0, + 20.0, + ], + ], + columns=expected_columns, + ), + ) + + +class BudgetOptimizationSpecsConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.BudgetOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.BudgetOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.OPTIMIZATION_SPECS) + # Expect two specs: one for each result. 
+    pd.testing.assert_frame_equal(
+        output_df,
+        pd.DataFrame(
+            [
+                # These correspond to the fixed spec FOO:
+                [
+                    _GID1,
+                    "2024-01-01",
+                    "2024-01-15",
+                    dc.ANALYSIS_TAG_ALL,
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_ROI,
+                    dc.OPTIMIZATION_SPEC_SCENARIO_FIXED,
+                    400.0,
+                    None,
+                    None,
+                    "Channel 1",
+                    0.0,
+                    100000.0,
+                ],
+                [
+                    _GID1,
+                    "2024-01-01",
+                    "2024-01-15",
+                    dc.ANALYSIS_TAG_ALL,
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_ROI,
+                    dc.OPTIMIZATION_SPEC_SCENARIO_FIXED,
+                    400.0,
+                    None,
+                    None,
+                    "Channel 2",
+                    0.0,
+                    100000.0,
+                ],
+                # These correspond to the flexible spec BAR:
+                [
+                    _GID2,
+                    "2024-01-08",
+                    "2024-01-15",
+                    "Week2",
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_KPI,
+                    dc.OPTIMIZATION_SPEC_SCENARIO_FLEXIBLE,
+                    200.0,
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_CPIK,
+                    10.0,
+                    "Channel 1",
+                    1100.0,
+                    1500.0,
+                ],
+                [
+                    _GID2,
+                    "2024-01-08",
+                    "2024-01-15",
+                    "Week2",
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_KPI,
+                    dc.OPTIMIZATION_SPEC_SCENARIO_FLEXIBLE,
+                    200.0,
+                    dc.OPTIMIZATION_SPEC_TARGET_METRIC_CPIK,
+                    10.0,
+                    "Channel 2",
+                    1000.0,
+                    1800.0,
+                ],
+            ],
+            columns=[
+                dc.OPTIMIZATION_GROUP_ID_COLUMN,
+                dc.OPTIMIZATION_SPEC_DATE_INTERVAL_START_COLUMN,
+                dc.OPTIMIZATION_SPEC_DATE_INTERVAL_END_COLUMN,
+                dc.ANALYSIS_PERIOD_COLUMN,
+                dc.OPTIMIZATION_SPEC_OBJECTIVE_COLUMN,
+                dc.OPTIMIZATION_SPEC_SCENARIO_TYPE_COLUMN,
+                dc.OPTIMIZATION_SPEC_INITIAL_CHANNEL_SPEND_COLUMN,
+                dc.OPTIMIZATION_SPEC_TARGET_METRIC_CONSTRAINT_COLUMN,
+                dc.OPTIMIZATION_SPEC_TARGET_METRIC_VALUE_COLUMN,
+                dc.OPTIMIZATION_CHANNEL_COLUMN,
+                dc.OPTIMIZATION_SPEC_CHANNEL_SPEND_MIN_COLUMN,
+                dc.OPTIMIZATION_SPEC_CHANNEL_SPEND_MAX_COLUMN,
+            ],
+        ),
+    )
+
+
+class BudgetOptimizationResultsConverterTest(absltest.TestCase):
+
+  def test_call_no_results(self):
+    conv = converters.BudgetOptimizationResultsConverter(
+        mmm_wrapper=mmm.Mmm(mmm_pb.Mmm())
+    )
+
+    self.assertEmpty(list(conv()))
+
+  def test_call_duplicate_group_id(self):
+    mmm_proto = mmm_pb.Mmm()
+    mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO)
+    mmm_proto.marketing_optimization.budget_optimization.results.append(
+        td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES,
+    )
+
+    with self.assertRaisesRegex(
+        ValueError, "Specified group_id must be unique"
+    ):
+      conv = converters.BudgetOptimizationResultsConverter(
+          mmm_wrapper=mmm.Mmm(mmm_proto)
+      )
+      next(conv())
+
+  def test_call(self):
+    conv = converters.BudgetOptimizationResultsConverter(
+        mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO)
+    )
+
+    name, output_df = next(conv())
+
+    self.assertEqual(name, dc.OPTIMIZATION_RESULTS)
+    pd.testing.assert_frame_equal(
+        output_df,
+        pd.DataFrame(
+            [
+                [
+                    _GID1,
+                    "Channel 1",
+                    True,
+                    75000.0,
+                    0.5,
+                    2.2,
+                    1.0,
+                    10.0,
+                    5.0,
+                ],
+                [
+                    _GID1,
+                    "Channel 1",
+                    False,
+                    75000.0,
+                    0.5,
+                    5.5,
+                    10.0,
+                    100.0,
+                    100.0,
+                ],
+                [
+                    _GID1,
+                    "Channel 2",
+                    True,
+                    25000.0,
+                    (1.0 / 6.0),
+                    4.4,
+                    2.0,
+                    20.0,
+                    10.0,
+                ],
+                [
+                    _GID1,
+                    "Channel 2",
+                    False,
+                    25000.0,
+                    (1.0 / 6.0),
+                    11.0,
+                    20.0,
+                    200.0,
+                    200.0,
+                ],
+                [
+                    _GID1,
+                    "RF Channel 1",
+                    True,
+                    30000.0,
+                    0.2,
+                    2.2,
+                    1.0,
+                    10.0,
+                    5.0,
+                ],
+                [
+                    _GID1,
+                    "RF Channel 1",
+                    False,
+                    30000.0,
+                    0.2,
+                    5.5,
+                    10.0,
+                    100.0,
+                    100.0,
+                ],
+                [
+                    _GID1,
+                    "RF Channel 2",
+                    True,
+                    20000.0,
+                    (2.0 / 15.0),
+                    4.4,
+                    2.0,
+                    20.0,
+                    10.0,
+                ],
+                [
+                    _GID1,
+                    "RF Channel 2",
+                    False,
+                    20000.0,
+                    (2.0 / 15.0),
+                    11.0,
+                    20.0,
+                    200.0,
+                    200.0,
+                ],
+                [
+                    _GID2,
+                    "Channel 1",
+                    False,
+                    75000.0,
+                    0.5,
+                    6.6,
+                    12.0,
+                    120.0,
+                    120.0,
+                ],
+                [
+                    _GID2,
+                    "Channel 2",
+                    False,
+                    25000.0,
+                    (1.0 / 6.0),
+                    12.1,
+                    22.0,
+                    220.0,
+                    220.0,
+ ], + [ + _GID2, + "RF Channel 1", + False, + 30000.0, + 0.2, + 6.6, + 12.0, + 120.0, + 120.0, + ], + [ + _GID2, + "RF Channel 2", + False, + 20000.0, + (2.0 / 15.0), + 12.1, + 22.0, + 220.0, + 220.0, + ], + ], + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN, + dc.OPTIMIZATION_RESULT_SPEND_COLUMN, + dc.OPTIMIZATION_RESULT_SPEND_SHARE_COLUMN, + dc.OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN, + dc.OPTIMIZATION_RESULT_ROI_COLUMN, + dc.OPTIMIZATION_RESULT_MROI_COLUMN, + dc.OPTIMIZATION_RESULT_CPC_COLUMN, + ], + ), + ) + + +class BudgetOptimizationResponseCurvesConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.BudgetOptimizationResponseCurvesConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.BudgetOptimizationResponseCurvesConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.OPTIMIZATION_RESPONSE_CURVES) + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + [ + [_GID1, "Channel 1", 1.0, 100.0], + [_GID1, "Channel 1", 2.0, 200.0], + [_GID1, "Channel 2", 2.0, 200.0], + [_GID1, "Channel 2", 4.0, 400.0], + [_GID1, "RF Channel 1", 1.0, 100.0], + [_GID1, "RF Channel 1", 2.0, 200.0], + [_GID1, "RF Channel 2", 2.0, 200.0], + [_GID1, "RF Channel 2", 4.0, 400.0], + [_GID1, "All Channels", 10.0, 1000.0], + [_GID1, "All Channels", 20.0, 2000.0], + [_GID2, "Channel 1", 1.2, 120.0], + [_GID2, "Channel 1", 2.4, 240.0], + [_GID2, "Channel 2", 2.2, 220.0], + [_GID2, "Channel 2", 4.4, 440.0], + [_GID2, "RF Channel 1", 1.2, 120.0], + [_GID2, "RF Channel 1", 2.4, 240.0], + [_GID2, "RF Channel 2", 2.2, 220.0], + [_GID2, "RF Channel 2", 4.4, 440.0], + [_GID2, "All Channels", 10.0, 1000.0], + [_GID2, "All Channels", 20.0, 2000.0], + ], + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_GRID_SPEND_COLUMN, + dc.OPTIMIZATION_GRID_INCREMENTAL_OUTCOME_COLUMN, + ], + ), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/scenarioplanner/converters/dataframe/common.py b/scenarioplanner/converters/dataframe/common.py new file mode 100644 index 000000000..69c571b8c --- /dev/null +++ b/scenarioplanner/converters/dataframe/common.py @@ -0,0 +1,71 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utility functions in this package.""" + +import re + +from mmm.v1.common import target_metric_pb2 as target_metric_pb +from scenarioplanner.converters.dataframe import constants as dc + + +def map_target_metric_str(metric: target_metric_pb.TargetMetric) -> str: + """Maps a TargetMetric enum to a string. + + Args: + metric: The TargetMetric enum to map. + + Returns: + The string representation of the TargetMetric enum. 
+ """ + match metric: + case target_metric_pb.TargetMetric.KPI: + return dc.OPTIMIZATION_SPEC_TARGET_METRIC_KPI + case target_metric_pb.TargetMetric.ROI: + return dc.OPTIMIZATION_SPEC_TARGET_METRIC_ROI + case target_metric_pb.TargetMetric.MARGINAL_ROI: + return dc.OPTIMIZATION_SPEC_TARGET_METRIC_MARGINAL_ROI + case target_metric_pb.TargetMetric.COST_PER_INCREMENTAL_KPI: + return dc.OPTIMIZATION_SPEC_TARGET_METRIC_CPIK + case _: + raise ValueError(f"Unsupported target metric: {metric}") + + +def _to_sheet_name_format(s: str) -> str: + """Converts a string to a sheet name format. + + Replace consecutive spaces with a single underscore using regex. + + Args: + s: The string to convert. + + Returns: + The converted sheet name. + """ + return re.sub(r"\s+", dc.SHEET_NAME_DELIMITER, s) + + +def create_grid_sheet_name(prefix: str, grid_name: str) -> str: + """Creates a grid sheet name with the given prefix and grid name. + + Args: + prefix: The prefix of the sheet name. + grid_name: The name of the grid. + + Returns: + The grid sheet name. + """ + grid_sheet_name = _to_sheet_name_format(grid_name) + sheet_prefix = _to_sheet_name_format(prefix) + return f"{sheet_prefix}_{grid_sheet_name}" diff --git a/scenarioplanner/converters/dataframe/common_test.py b/scenarioplanner/converters/dataframe/common_test.py new file mode 100644 index 000000000..225f5d11f --- /dev/null +++ b/scenarioplanner/converters/dataframe/common_test.py @@ -0,0 +1,54 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +from scenarioplanner.converters.dataframe import common + + +class CommonTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="name_with_underscore", + prefix="prefix", + name="name_with_underscore", + expected="prefix_name_with_underscore", + ), + dict( + testcase_name="name_with_multiple_spaces", + prefix="prefix", + name="name with multiple spaces", + expected="prefix_name_with_multiple_spaces", + ), + dict( + testcase_name="prefix_with_underscore", + prefix="prefix_with_underscore", + name="name", + expected="prefix_with_underscore_name", + ), + dict( + testcase_name="prefix_with_multiple_spaces", + prefix="prefix with multiple spaces", + name="name", + expected="prefix_with_multiple_spaces_name", + ), + ) + def test_create_grid_sheet_name(self, prefix, name, expected): + got = common.create_grid_sheet_name(prefix, name) + self.assertEqual(expected, got) + + +if __name__ == "__main__": + absltest.main() diff --git a/scenarioplanner/converters/dataframe/constants.py b/scenarioplanner/converters/dataframe/constants.py new file mode 100644 index 000000000..941dc4227 --- /dev/null +++ b/scenarioplanner/converters/dataframe/constants.py @@ -0,0 +1,137 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataframe converter constants.""" + +SHEET_NAME_DELIMITER = "_" + +# Special analysis aggregation tags. +ANALYSIS_TAG_ALL = "ALL" + +# ModelFit table column names +MODEL_FIT = "ModelFit" +MODEL_FIT_TIME_COLUMN = "Time" +MODEL_FIT_EXPECTED_CI_LOW_COLUMN = "Expected CI Low" +MODEL_FIT_EXPECTED_CI_HIGH_COLUMN = "Expected CI High" +MODEL_FIT_EXPECTED_COLUMN = "Expected" +MODEL_FIT_BASELINE_COLUMN = "Baseline" +MODEL_FIT_ACTUAL_COLUMN = "Actual" + +# ModelDiagnostics table column names +MODEL_DIAGNOSTICS = "ModelDiagnostics" +MODEL_DIAGNOSTICS_DATASET_COLUMN = "Dataset" +MODEL_DIAGNOSTICS_R_SQUARED_COLUMN = "R Squared" +MODEL_DIAGNOSTICS_MAPE_COLUMN = "MAPE" +MODEL_DIAGNOSTICS_WMAPE_COLUMN = "wMAPE" + +# Common column names +ANALYSIS_PERIOD_COLUMN = "Analysis Period" +ANALYSIS_DATE_START_COLUMN = "Analysis Date Start" +ANALYSIS_DATE_END_COLUMN = "Analysis Date End" + +# MediaOutcome table column names +MEDIA_OUTCOME = "MediaOutcome" +MEDIA_OUTCOME_CHANNEL_INDEX_COLUMN = "Channel Index" +MEDIA_OUTCOME_CHANNEL_COLUMN = "Channel" +MEDIA_OUTCOME_INCREMENTAL_OUTCOME_COLUMN = "Incremental Outcome" +MEDIA_OUTCOME_CONTRIBUTION_SHARE_COLUMN = "Contribution Share" +MEDIA_OUTCOME_BASELINE_PSEUDO_CHANNEL_INDEX = 0 +MEDIA_OUTCOME_ALL_CHANNELS_PSEUDO_CHANNEL_INDEX = 1 +MEDIA_OUTCOME_CHANNEL_INDEX = 2 + +# MediaSpend table column names +MEDIA_SPEND = "MediaSpend" +MEDIA_SPEND_CHANNEL_COLUMN = "Channel" +MEDIA_SPEND_SHARE_VALUE_COLUMN = "Share Value" +MEDIA_SPEND_LABEL_COLUMN = "Label" +# The "Label" column enums +MEDIA_SPEND_LABEL_SPEND_SHARE = "Spend Share" +MEDIA_SPEND_LABEL_REVENUE_SHARE = "Revenue Share" +MEDIA_SPEND_LABEL_KPI_SHARE = "KPI Share" + +# MediaROI table column names +MEDIA_ROI = "MediaROI" +MEDIA_ROI_CHANNEL_COLUMN = "Channel" +MEDIA_ROI_SPEND_COLUMN = "Spend" +MEDIA_ROI_EFFECTIVENESS_COLUMN = "Effectiveness" +MEDIA_ROI_ROI_COLUMN = "ROI" +MEDIA_ROI_ROI_CI_LOW_COLUMN = "ROI CI Low" +MEDIA_ROI_ROI_CI_HIGH_COLUMN = "ROI CI High" +MEDIA_ROI_MARGINAL_ROI_COLUMN = "Marginal ROI" +MEDIA_ROI_IS_REVENUE_KPI_COLUMN = "Is Revenue KPI" + + +# Shared column names among Optimization tables +OPTIMIZATION_GROUP_ID_COLUMN = "Group ID" +OPTIMIZATION_CHANNEL_COLUMN = "Channel" + +# Optimization grid table column names +# (Table name is user-generated from the spec) +OPTIMIZATION_GRID_SPEND_COLUMN = "Spend" +OPTIMIZATION_GRID_INCREMENTAL_OUTCOME_COLUMN = "Incremental Outcome" + +# R&F Optimization grid table column names +# (Table name is user-generated from the spec) +RF_OPTIMIZATION_GRID_FREQ_COLUMN = "Frequency" +RF_OPTIMIZATION_GRID_ROI_OUTCOME_COLUMN = "ROI" + +# Budget optimization grid table name +OPTIMIZATION_GRID_NAME_PREFIX = "budget_opt_grid" + +# R&F optimization grid table name +RF_OPTIMIZATION_GRID_NAME_PREFIX = "rf_opt_grid" + +# Optimization spec table column names and enum values +OPTIMIZATION_SPECS = "budget_opt_specs" +OPTIMIZATION_SPEC_DATE_INTERVAL_START_COLUMN = "Date Interval Start" +OPTIMIZATION_SPEC_DATE_INTERVAL_END_COLUMN = "Date Interval End" +OPTIMIZATION_SPEC_OBJECTIVE_COLUMN = "Objective" +OPTIMIZATION_SPEC_SCENARIO_TYPE_COLUMN = "Scenario Type" 
+OPTIMIZATION_SPEC_SCENARIO_FIXED = "Fixed" +OPTIMIZATION_SPEC_SCENARIO_FLEXIBLE = "Flexible" +OPTIMIZATION_SPEC_INITIAL_CHANNEL_SPEND_COLUMN = "Initial Channel Spend" +OPTIMIZATION_SPEC_TARGET_METRIC_CONSTRAINT_COLUMN = "Target Metric Constraint" +OPTIMIZATION_SPEC_TARGET_METRIC_KPI = "KPI" +OPTIMIZATION_SPEC_TARGET_METRIC_ROI = "ROI" +OPTIMIZATION_SPEC_TARGET_METRIC_MARGINAL_ROI = "Marginal ROI" +OPTIMIZATION_SPEC_TARGET_METRIC_CPIK = "Cost per Incremental KPI" +OPTIMIZATION_SPEC_TARGET_METRIC_VALUE_COLUMN = "Target Metric Value" +OPTIMIZATION_SPEC_CHANNEL_COLUMN = "Channel" +OPTIMIZATION_SPEC_CHANNEL_SPEND_MIN_COLUMN = "Channel Spend Min" +OPTIMIZATION_SPEC_CHANNEL_SPEND_MAX_COLUMN = "Channel Spend Max" + +# R&F Optimization spec table column names and enum values +RF_OPTIMIZATION_SPECS = "rf_opt_specs" +RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MIN_COLUMN = "Channel Frequency Min" +RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MAX_COLUMN = "Channel Frequency Max" + +# Optimization results table column names +OPTIMIZATION_RESULTS = "budget_opt_results" +OPTIMIZATION_RESULT_SPEND_COLUMN = "Optimal Spend" +OPTIMIZATION_RESULT_SPEND_SHARE_COLUMN = "Optimal Spend Share" +OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN = "Optimal Impression Effectiveness" +OPTIMIZATION_RESULT_ROI_COLUMN = "Optimal ROI" +OPTIMIZATION_RESULT_MROI_COLUMN = "Optimal mROI" +OPTIMIZATION_RESULT_CPC_COLUMN = "Optimal CPC" +OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN = "Is Revenue KPI" + +# R&F Optimization results table column names +RF_OPTIMIZATION_RESULTS = "rf_opt_results" +RF_OPTIMIZATION_RESULT_INITIAL_SPEND_COLUMN = "Initial Spend" +RF_OPTIMIZATION_RESULT_AVG_FREQ_COLUMN = "Optimal Avg Frequency" + +# Optimization results' response curves table column names +OPTIMIZATION_RESPONSE_CURVES = "response_curves" +OPTIMIZATION_RESPONSE_CURVE_SPEND_COLUMN = "Spend" +OPTIMIZATION_RESPONSE_CURVE_INCREMENTAL_OUTCOME_COLUMN = "Incremental Outcome" diff --git a/scenarioplanner/converters/dataframe/converter.py b/scenarioplanner/converters/dataframe/converter.py new file mode 100644 index 000000000..1fe405cdc --- /dev/null +++ b/scenarioplanner/converters/dataframe/converter.py @@ -0,0 +1,42 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""`Converter` class for all dataframe converters.""" + +import abc +from collections.abc import Iterator + +from scenarioplanner.converters import mmm +import pandas as pd + + +__all__ = ["Converter"] + + +class Converter(abc.ABC): + """Converts a trained model and analyses to one or more data frame tables. + + Attributes: + mmm: An `Mmm` proto wrapper. 
+ """ + + def __init__( + self, + mmm_wrapper: mmm.Mmm, + ): + self._mmm = mmm_wrapper + + @abc.abstractmethod + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + raise NotImplementedError() diff --git a/scenarioplanner/converters/dataframe/dataframe_model_converter.py b/scenarioplanner/converters/dataframe/dataframe_model_converter.py new file mode 100644 index 000000000..2cc0a7b33 --- /dev/null +++ b/scenarioplanner/converters/dataframe/dataframe_model_converter.py @@ -0,0 +1,70 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An output converter that denormalizes into flat data frame tables.""" + +from collections.abc import Mapping, Sequence + +from scenarioplanner.converters import mmm_converter +from scenarioplanner.converters.dataframe import budget_opt_converters +from scenarioplanner.converters.dataframe import converter +from scenarioplanner.converters.dataframe import marketing_analyses_converters +from scenarioplanner.converters.dataframe import rf_opt_converters +import pandas as pd + + +__all__ = ["DataFrameModelConverter"] + + +class DataFrameModelConverter(mmm_converter.ModelConverter[pd.DataFrame]): + """Converts a bound `Mmm` model into denormalized flat data frame tables. + + The denormalized, two-dimensional data frame tables are intended to be + directly compiled into sheets in a Google Sheets file to be used as a data + source for a Looker Studio dashboard. + + These data frame tables are: + + * "ModelDiagnostics" + * "ModelFit" + * "MediaOutcome" + * "MediaSpend" + * "MediaROI" + * (Named Incremental Outcome Grids) + * "budget_opt_specs" + * "budget_opt_results" + * "response_curves" + * (Named R&F ROI Grids) + * "rf_opt_specs" + * "rf_opt_results" + """ + + _converters: Sequence[type[converter.Converter]] = ( + marketing_analyses_converters.CONVERTERS + + budget_opt_converters.CONVERTERS + + rf_opt_converters.CONVERTERS + ) + + def __call__(self, **kwargs) -> Mapping[str, pd.DataFrame]: + """Converts bound `Mmm` model proto to named, flat data frame tables.""" + output = {} + + for converter_class in self._converters: + converter_instance = converter_class(self.mmm) # pytype: disable=not-instantiable + for table_name, table_data in converter_instance(): + if output.get(table_name) is not None: + raise ValueError(f"Duplicate table name: {table_name}") + output[table_name] = table_data + + return output diff --git a/scenarioplanner/converters/dataframe/dataframe_model_converter_test.py b/scenarioplanner/converters/dataframe/dataframe_model_converter_test.py new file mode 100644 index 000000000..b2c53a445 --- /dev/null +++ b/scenarioplanner/converters/dataframe/dataframe_model_converter_test.py @@ -0,0 +1,93 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.fit import model_fit_pb2 as fit_pb +from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb +from mmm.v1.marketing.optimization import marketing_optimization_pb2 as optimization_pb +from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from scenarioplanner.converters import test_data as td +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import dataframe_model_converter as converter + + +_DEFAULT_MMM_PROTO = mmm_pb.Mmm( + mmm_kernel=kernel_pb.MmmKernel( + marketing_data=td.MARKETING_DATA, + ), + model_fit=fit_pb.ModelFit( + results=[ + td.MODEL_FIT_RESULT_TRAIN, + td.MODEL_FIT_RESULT_TEST, + td.MODEL_FIT_RESULT_ALL_DATA, + ] + ), + marketing_analysis_list=td.MARKETING_ANALYSIS_LIST_BOTH_OUTCOMES, + marketing_optimization=optimization_pb.MarketingOptimization( + budget_optimization=budget_pb.BudgetOptimization( + results=[ + td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES, + td.BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV, + ] + ), + reach_frequency_optimization=rf_pb.ReachFrequencyOptimization( + results=[ + td.RF_OPTIMIZATION_RESULT_FOO, + ] + ), + ), +) + + +class DataFrameModelConverterTest(absltest.TestCase): + + def test_call(self): + conv = converter.DataFrameModelConverter(mmm_proto=_DEFAULT_MMM_PROTO) + + output = conv() + + expected_budget_opt_grid_name1 = "_".join( + [dc.OPTIMIZATION_GRID_NAME_PREFIX, "incremental_outcome_grid_foo"] + ) + + expected_budget_opt_grid_name2 = "_".join( + [dc.OPTIMIZATION_GRID_NAME_PREFIX, "incremental_outcome_grid_bar"] + ) + + expected_rf_opt_grid_name = "_".join( + [dc.RF_OPTIMIZATION_GRID_NAME_PREFIX, "frequency_outcome_grid_foo"] + ) + + for expected_table_name in [ + dc.MODEL_DIAGNOSTICS, + dc.MODEL_FIT, + dc.MEDIA_OUTCOME, + dc.MEDIA_SPEND, + dc.MEDIA_ROI, + expected_budget_opt_grid_name1, + expected_budget_opt_grid_name2, + dc.OPTIMIZATION_SPECS, + dc.OPTIMIZATION_RESULTS, + dc.OPTIMIZATION_RESPONSE_CURVES, + expected_rf_opt_grid_name, + dc.RF_OPTIMIZATION_SPECS, + dc.RF_OPTIMIZATION_RESULTS, + ]: + self.assertIn(expected_table_name, output.keys()) + + +if __name__ == "__main__": + absltest.main() diff --git a/scenarioplanner/converters/dataframe/marketing_analyses_converters.py b/scenarioplanner/converters/dataframe/marketing_analyses_converters.py new file mode 100644 index 000000000..94327d4ae --- /dev/null +++ b/scenarioplanner/converters/dataframe/marketing_analyses_converters.py @@ -0,0 +1,543 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Marketing analyses converters. + +This module defines various classes that convert `MarketingAnalysis`s into flat +dataframes. +""" + +import abc +from collections.abc import Iterator, Sequence +import datetime +import functools +import math +import warnings + +from meridian import constants as c +from mmm.v1.fit import model_fit_pb2 as fit_pb +from scenarioplanner.converters import mmm +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import converter +import pandas as pd + + +__all__ = [ + "ModelDiagnosticsConverter", + "ModelFitConverter", + "MediaOutcomeConverter", + "MediaSpendConverter", + "MediaRoiConverter", +] + + +class ModelDiagnosticsConverter(converter.Converter): + """Outputs a "ModelDiagnostics" table. + + When called, this converter yields a data frame with the columns: + + * "Dataset" + * "R Squared" + * "MAPE" + * "wMAPE" + """ + + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + if not self._mmm.model_fit_results: + return + + model_diagnostics_data = [] + for name, result in self._mmm.model_fit_results.items(): + model_diagnostics_data.append(( + name, + result.performance.r_squared, + result.performance.mape, + result.performance.weighted_mape, + )) + yield ( + dc.MODEL_DIAGNOSTICS, + pd.DataFrame( + model_diagnostics_data, + columns=[ + dc.MODEL_DIAGNOSTICS_DATASET_COLUMN, + dc.MODEL_DIAGNOSTICS_R_SQUARED_COLUMN, + dc.MODEL_DIAGNOSTICS_MAPE_COLUMN, + dc.MODEL_DIAGNOSTICS_WMAPE_COLUMN, + ], + ), + ) + + +class ModelFitConverter(converter.Converter): + """Outputs a "ModelFit" table from an "All Data" (*) `Result` dataset. + + Note: If there is no such result dataset, the first one available is used, + instead. + + When called, this converter yields a data frame with the columns: + + * "Time" + A string formatted with Meridian date format: YYYY-mm-dd + * "Expected CI Low" + * "Expected CI High" + * "Expected" + * "Baseline" + * "Actual" + """ + + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + if not self._mmm.model_fit_results: + return + + model_fit_data = [] + for prediction in self._select_model_fit_result().predictions: + time = datetime.datetime( + year=prediction.date_interval.start_date.year, + month=prediction.date_interval.start_date.month, + day=prediction.date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + + if not prediction.predicted_outcome.uncertainties: + (expected_ci_lo, expected_ci_hi) = (math.nan, math.nan) + else: + if len(prediction.predicted_outcome.uncertainties) > 1: + warnings.warn( + "More than one `Estimate.uncertainties` found in a" + " `Prediction.predicted_outcome` in `ModelFit`; processing only" + " the first confidence interval value." 
+ ) + uncertainty = prediction.predicted_outcome.uncertainties[0] + expected_ci_lo = uncertainty.lowerbound + expected_ci_hi = uncertainty.upperbound + expected = prediction.predicted_outcome.value + actual = prediction.actual_value + + baseline = prediction.predicted_baseline.value + + model_fit_data.append( + (time, expected_ci_lo, expected_ci_hi, expected, baseline, actual) + ) + + yield ( + dc.MODEL_FIT, + pd.DataFrame( + model_fit_data, + columns=[ + dc.MODEL_FIT_TIME_COLUMN, + dc.MODEL_FIT_EXPECTED_CI_LOW_COLUMN, + dc.MODEL_FIT_EXPECTED_CI_HIGH_COLUMN, + dc.MODEL_FIT_EXPECTED_COLUMN, + dc.MODEL_FIT_BASELINE_COLUMN, + dc.MODEL_FIT_ACTUAL_COLUMN, + ], + ), + ) + + def _select_model_fit_result(self) -> fit_pb.Result: + """Returns the model fit `Result` dataset with name "All Data". + + Or else, first available. + """ + model_fit_results = self._mmm.model_fit_results + if not model_fit_results: + raise ValueError("Must have at least one `ModelFit.results` value.") + if c.ALL_DATA in model_fit_results: + result = model_fit_results[c.ALL_DATA] + else: + result = self._mmm.model_fit.results[0] + warnings.warn(f"Using a model fit `Result` with name: '{result.name}'") + + return result + + +class _MarketingAnalysisConverter(converter.Converter, abc.ABC): + """An abstract class for dealing with `MarketingAnalysis`.""" + + @functools.cached_property + def _is_revenue_kpi(self) -> bool: + """Returns true if analyses are using revenue KPI. + + This is done heuristically: by looking at the (presumed existing) "baseline" + `NonMediaAnalysis` proto and seeing if `revenue_kpi` field is defined. If it + is, we assume that all other media analyses must have their `revenue_kpi` + fields defined, too. + + Likewise: if the baseline analysis defines `non_revenue_kpi` and it does not + define `revenue_kpi`, we assume that all other media analyses are based on + a non-revenue KPI. + + Note: This means that this output converter can only work with one type of + KPI as a whole. If a channel's media analysis has both revenue- and + nonrevenue-type KPI defined, for example, only the former will be outputted. + """ + baseline_analysis = self._mmm.tagged_marketing_analyses[ + dc.ANALYSIS_TAG_ALL + ].baseline_analysis + return baseline_analysis.maybe_revenue_outcome is not None + + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + if not self._mmm.marketing_analyses: + return + + yield from self._handle_marketing_analyses(self._mmm.marketing_analyses) + + def _handle_marketing_analyses( + self, analyses: Sequence[mmm.MarketingAnalysis] + ) -> Iterator[tuple[str, pd.DataFrame]]: + raise NotImplementedError() + + +class MediaOutcomeConverter(_MarketingAnalysisConverter): + """Outputs a "MediaOutcome" table. + + When called, this converter yields a data frame with the columns: + + * "Channel Index" + This is to ensure "baseline" and "All Channels" can be sorted to appear + first and last, respectively, in LS dashboard. + * "Channel" + * "Incremental Outcome" + * "Contribution Share" + * "Analysis Period" + A human-readable analysis period. + * "Analysis Date Start" + A string formatted with Meridian date format: YYYY-mm-dd + * "Analysis Date End" + A string formatted with Meridian date format: YYYY-mm-dd + + Note: If the underlying model analysis works with a revenue-type KPI (i.e. + dollar value), then all values in the columns of the output table should be + interpreted the same. Likewise, for non-revenue type KPI. 
While some + channels may define their outcome analyses in terms of both revenue- and + nonrevenue-type semantics, the output table here remains uniform. + """ + + def _handle_marketing_analyses( + self, analyses: Sequence[mmm.MarketingAnalysis] + ) -> Iterator[tuple[str, pd.DataFrame]]: + media_outcome_data = [] + for marketing_analysis in analyses: + date_start, date_end = marketing_analysis.analysis_date_interval_str + + baseline_outcome: mmm.Outcome = ( + marketing_analysis.baseline_analysis.revenue_outcome + if self._is_revenue_kpi + else marketing_analysis.baseline_analysis.non_revenue_outcome + ) + # "contribution" == incremental outcome + baseline_contrib = baseline_outcome.contribution_pb.value.value + baseline_contrib_share = baseline_outcome.contribution_pb.share.value + + media_outcome_data.append(( + dc.MEDIA_OUTCOME_BASELINE_PSEUDO_CHANNEL_INDEX, + c.BASELINE, + baseline_contrib, + baseline_contrib_share, + marketing_analysis.tag, + date_start, + date_end, + )) + media_analyses = list( + marketing_analysis.channel_mapped_media_analyses.items() + ) + non_media_analyses = list(filter( + lambda x: x[0] != c.BASELINE, + list(marketing_analysis.channel_mapped_non_media_analyses.items()), + )) + all_analyses = media_analyses + non_media_analyses + for channel, media_analysis in all_analyses: + channel_index = ( + dc.MEDIA_OUTCOME_ALL_CHANNELS_PSEUDO_CHANNEL_INDEX + if channel == c.ALL_CHANNELS + else dc.MEDIA_OUTCOME_CHANNEL_INDEX + ) + # Note: use the same revenue- or nonrevenue-type outcome analysis as the + # baseline's. + try: + channel_outcome: mmm.Outcome = ( + media_analysis.revenue_outcome + if self._is_revenue_kpi + else media_analysis.non_revenue_outcome + ) + except ValueError: + warnings.warn( + f"No {'' if self._is_revenue_kpi else 'non'}revenue-type" + " `Outcome` found in the channel media analysis for" + f' "{channel}"' + ) + channel_contrib = math.nan + channel_contrib_share = math.nan + else: + channel_contrib = channel_outcome.contribution_pb.value.value + channel_contrib_share = channel_outcome.contribution_pb.share.value + + media_outcome_data.append(( + channel_index, + channel, + channel_contrib, + channel_contrib_share, + marketing_analysis.tag, + date_start, + date_end, + )) + + yield ( + dc.MEDIA_OUTCOME, + pd.DataFrame( + media_outcome_data, + columns=[ + dc.MEDIA_OUTCOME_CHANNEL_INDEX_COLUMN, + dc.MEDIA_OUTCOME_CHANNEL_COLUMN, + dc.MEDIA_OUTCOME_INCREMENTAL_OUTCOME_COLUMN, + dc.MEDIA_OUTCOME_CONTRIBUTION_SHARE_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, # using the `tag` field + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + +class MediaSpendConverter(_MarketingAnalysisConverter): + """Outputs a "MediaSpend" table. + + When called, this converter yields a data frame with the columns: + + * "Channel" + * "Value" + * "Label" + A human-readable label on what "Value" represents + * "Analysis Period" + A human-readable analysis period. + * "Analysis Date Start" + A string formatted with Meridian date format: YYYY-mm-dd + * "Analysis Date End" + A string formatted with Meridian date format: YYYY-mm-dd + + Note: If the underlying model analysis works with a revenue-type KPI (i.e. + dollar value), then all values in the columns of the output table should + be interpreted the same. Likewise, for non-revenue type KPI. While some + channels may define their outcome analyses in terms of both revenue- and + nonrevenue-type semantics, the output table here remains uniform. 
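+
+  A minimal usage sketch (`mmm_proto` is a hypothetical, populated `Mmm`
+  proto; converters are normally driven by `DataFrameModelConverter`):
+
+    conv = MediaSpendConverter(mmm_wrapper=mmm.Mmm(mmm_proto))
+    name, df = next(conv())  # name == "MediaSpend"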
+  """
+
+  _share_value_column_index = 1
+  _label_column_index = 2
+
+  def _handle_marketing_analyses(
+      self, analyses: Sequence[mmm.MarketingAnalysis]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    media_spend_data = []
+
+    for analysis in analyses:
+      date_start, date_end = analysis.analysis_date_interval_str
+      data = []
+      outcome_share_norm_term = 0.0
+
+      for (
+          channel,
+          media_analysis,
+      ) in analysis.channel_mapped_media_analyses.items():
+        # Ignore the "All Channels" pseudo-channel.
+        if channel == c.ALL_CHANNELS:
+          continue
+
+        spend_share = media_analysis.spend_info_pb.spend_share
+
+        try:
+          channel_outcome: mmm.Outcome = (
+              media_analysis.revenue_outcome
+              if self._is_revenue_kpi
+              else media_analysis.non_revenue_outcome
+          )
+        except ValueError:
+          warnings.warn(
+              f"No {'' if self._is_revenue_kpi else 'non'}revenue-type"
+              " `Outcome` found in the channel media analysis for"
+              f' "{channel}"'
+          )
+          outcome_share = math.nan
+        else:
+          outcome_share = channel_outcome.contribution_pb.share.value
+          outcome_share_norm_term += outcome_share
+
+        data.append([
+            channel,
+            spend_share,
+            dc.MEDIA_SPEND_LABEL_SPEND_SHARE,
+            analysis.tag,
+            date_start,
+            date_end,
+        ])
+        data.append([
+            channel,
+            outcome_share,
+            (
+                dc.MEDIA_SPEND_LABEL_REVENUE_SHARE
+                if self._is_revenue_kpi
+                else dc.MEDIA_SPEND_LABEL_KPI_SHARE
+            ),
+            analysis.tag,
+            date_start,
+            date_end,
+        ])
+
+      # Looker Studio media spend/revenue share charts expect the "revenue
+      # share" values to be normalized to 100%. This normalization complements
+      # the information that the contribution waterfall chart already provides.
+      for d in data:
+        if d[self._label_column_index] == dc.MEDIA_SPEND_LABEL_SPEND_SHARE:
+          continue
+        d[self._share_value_column_index] /= outcome_share_norm_term
+
+      media_spend_data.extend(data)
+
+    yield (
+        dc.MEDIA_SPEND,
+        pd.DataFrame(
+            media_spend_data,
+            columns=[
+                dc.MEDIA_SPEND_CHANNEL_COLUMN,
+                dc.MEDIA_SPEND_SHARE_VALUE_COLUMN,
+                dc.MEDIA_SPEND_LABEL_COLUMN,
+                dc.ANALYSIS_PERIOD_COLUMN,  # using the `tag` field
+                dc.ANALYSIS_DATE_START_COLUMN,
+                dc.ANALYSIS_DATE_END_COLUMN,
+            ],
+        ),
+    )
+
+
+class MediaRoiConverter(_MarketingAnalysisConverter):
+  """Outputs a "MediaROI" table.
+
+  When called, this converter yields a data frame with the columns:
+
+  * "Channel"
+  * "Spend"
+  * "Effectiveness"
+  * "ROI"
+  * "ROI CI Low"
+    The confidence interval (low) of "ROI" above.
+  * "ROI CI High"
+    The confidence interval (high) of "ROI" above.
+  * "Marginal ROI"
+  * "Is Revenue KPI"
+    A boolean indicating whether "ROI" refers to revenue or generic KPI.
+  * "Analysis Period"
+    A human-readable analysis period.
+  * "Analysis Date Start"
+    A string formatted with Meridian date format: YYYY-mm-dd
+  * "Analysis Date End"
+    A string formatted with Meridian date format: YYYY-mm-dd
+
+  Note: If the underlying model analysis works with a revenue-type KPI (i.e.
+  dollar value), then all values in the columns of the output table should
+  be interpreted the same. Likewise, for non-revenue type KPI. While some
+  channels may define their outcome analyses in terms of both revenue- and
+  nonrevenue-type semantics, the output table here remains uniform.
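+
+  A minimal consumption sketch (`mmm_proto` is a hypothetical, populated
+  `Mmm` proto; the column names come from `constants.py`):
+
+    conv = MediaRoiConverter(mmm_wrapper=mmm.Mmm(mmm_proto))
+    _, df = next(conv())
+    roi_by_channel = df.set_index("Channel")["ROI"]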
+  """
+
+  def _handle_marketing_analyses(
+      self, analyses: Sequence[mmm.MarketingAnalysis]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    media_roi_data = []
+    for analysis in analyses:
+      date_start, date_end = analysis.analysis_date_interval_str
+
+      for (
+          channel,
+          media_analysis,
+      ) in analysis.channel_mapped_media_analyses.items():
+        # Ignore the "All Channels" pseudo-channel.
+        if channel == c.ALL_CHANNELS:
+          continue
+
+        spend = media_analysis.spend_info_pb.spend
+
+        try:
+          channel_outcome: mmm.Outcome = (
+              media_analysis.revenue_outcome
+              if self._is_revenue_kpi
+              else media_analysis.non_revenue_outcome
+          )
+        except ValueError as exc:
+          raise ValueError(
+              f"No {'' if self._is_revenue_kpi else 'non'}revenue-type"
+              " `Outcome` found in the channel media analysis for"
+              f' "{channel}"'
+          ) from exc
+        else:
+          effectiveness = channel_outcome.effectiveness_pb.value.value
+          roi_estimate = channel_outcome.roi_pb
+          if not roi_estimate.uncertainties:
+            (roi_ci_lo, roi_ci_hi) = (math.nan, math.nan)
+          else:
+            if len(roi_estimate.uncertainties) > 1:
+              warnings.warn(
+                  "More than one `Estimate.uncertainties` found in an"
+                  f' `Outcome` ROI estimate in channel "{channel}".'
+                  " Using the first confidence interval value."
+              )
+            uncertainty = roi_estimate.uncertainties[0]
+            roi_ci_lo = uncertainty.lowerbound
+            roi_ci_hi = uncertainty.upperbound
+          roi = roi_estimate.value
+          marginal_roi = channel_outcome.marginal_roi_pb.value
+          is_revenue_kpi = channel_outcome.is_revenue_kpi
+
+          media_roi_data.append([
+              channel,
+              spend,
+              effectiveness,
+              roi,
+              roi_ci_lo,
+              roi_ci_hi,
+              marginal_roi,
+              is_revenue_kpi,
+              analysis.tag,
+              date_start,
+              date_end,
+          ])
+
+    yield (
+        dc.MEDIA_ROI,
+        pd.DataFrame(
+            media_roi_data,
+            columns=[
+                dc.MEDIA_ROI_CHANNEL_COLUMN,
+                dc.MEDIA_ROI_SPEND_COLUMN,
+                dc.MEDIA_ROI_EFFECTIVENESS_COLUMN,
+                dc.MEDIA_ROI_ROI_COLUMN,
+                dc.MEDIA_ROI_ROI_CI_LOW_COLUMN,
+                dc.MEDIA_ROI_ROI_CI_HIGH_COLUMN,
+                dc.MEDIA_ROI_MARGINAL_ROI_COLUMN,
+                dc.MEDIA_ROI_IS_REVENUE_KPI_COLUMN,
+                dc.ANALYSIS_PERIOD_COLUMN,
+                dc.ANALYSIS_DATE_START_COLUMN,
+                dc.ANALYSIS_DATE_END_COLUMN,
+            ],
+        ),
+    )
+
+
+CONVERTERS = [
+    # These converters create tables for the model analysis charts to use:
+    ModelDiagnosticsConverter,
+    ModelFitConverter,
+    MediaOutcomeConverter,
+    MediaSpendConverter,
+    MediaRoiConverter,
+]
diff --git a/scenarioplanner/converters/dataframe/marketing_analyses_converters_test.py b/scenarioplanner/converters/dataframe/marketing_analyses_converters_test.py
new file mode 100644
index 000000000..fc635144b
--- /dev/null
+++ b/scenarioplanner/converters/dataframe/marketing_analyses_converters_test.py
@@ -0,0 +1,776 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
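+
+# All converters in this package share one calling convention: construct with
+# a wrapped `Mmm` proto, then iterate `(table_name, DataFrame)` pairs. A
+# minimal sketch (mirrors the tests below; the CSV sink is illustrative only):
+#
+#   conv = converters.MediaRoiConverter(
+#       mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO)
+#   )
+#   for table_name, df in conv():
+#     df.to_csv(f"{table_name}.csv", index=False)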
+ +import datetime + +from absl.testing import absltest +from meridian import constants as c +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.fit import model_fit_pb2 as fit_pb +from scenarioplanner.converters import mmm +from scenarioplanner.converters import test_data as td +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import marketing_analyses_converters as converters +import pandas as pd + + +_DEFAULT_MMM_PROTO = mmm_pb.Mmm( + model_fit=fit_pb.ModelFit( + results=[ + td.MODEL_FIT_RESULT_TRAIN, + td.MODEL_FIT_RESULT_TEST, + td.MODEL_FIT_RESULT_ALL_DATA, + ] + ), + marketing_analysis_list=td.MARKETING_ANALYSIS_LIST_BOTH_OUTCOMES, +) + +_NONREVENUE_MMM_PROTO = mmm_pb.Mmm( + model_fit=fit_pb.ModelFit( + results=[td.MODEL_FIT_RESULT_TRAIN, td.MODEL_FIT_RESULT_TEST] + ), + marketing_analysis_list=td.MARKETING_ANALYSIS_LIST_NONREVENUE, +) + + +class ModelDiagnosticsConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.ModelDiagnosticsConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.ModelDiagnosticsConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MODEL_DIAGNOSTICS) + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + [ + [ + c.TRAIN, + 0.91, + 60.6, + 55.5, + ], + [ + c.TEST, + 0.99, + 67.7, + 59.8, + ], + [ + c.ALL_DATA, + 0.94, + 60.0, + 55.4, + ], + ], + columns=[ + dc.MODEL_DIAGNOSTICS_DATASET_COLUMN, + dc.MODEL_DIAGNOSTICS_R_SQUARED_COLUMN, + dc.MODEL_DIAGNOSTICS_MAPE_COLUMN, + dc.MODEL_DIAGNOSTICS_WMAPE_COLUMN, + ], + ), + ) + + +class ModelFitConverterTest(absltest.TestCase): + + def test_call_missing_result(self): + conv = converters.ModelFitConverter( + mmm_wrapper=mmm.Mmm( + mmm_proto=mmm_pb.Mmm(model_fit=fit_pb.ModelFit(results=[])) + ) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.ModelFitConverter(mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO)) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MODEL_FIT) + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + [ + [ + "2024-01-01", + 90.0, + 110.0, + 100.0, + 90.0, + 105.0, + ], + [ + "2024-01-08", + 100.0, + 120.0, + 110.0, + 109.0, + 115.0, + ], + ], + columns=[ + dc.MODEL_FIT_TIME_COLUMN, + dc.MODEL_FIT_EXPECTED_CI_LOW_COLUMN, + dc.MODEL_FIT_EXPECTED_CI_HIGH_COLUMN, + dc.MODEL_FIT_EXPECTED_COLUMN, + dc.MODEL_FIT_BASELINE_COLUMN, + dc.MODEL_FIT_ACTUAL_COLUMN, + ], + ), + ) + + def test_model_fit_result_no_all_data_result(self): + conv = converters.ModelFitConverter( + mmm_wrapper=mmm.Mmm( + mmm_proto=mmm_pb.Mmm( + model_fit=fit_pb.ModelFit( + results=[ + td.MODEL_FIT_RESULT_TRAIN, + td.MODEL_FIT_RESULT_TEST, + ] + ) + ) + ) + ) + + with self.assertWarnsRegex( + UserWarning, + expected_regex="Using a model fit `Result` with name: 'Train'", + ): + _ = next(conv()) + + +class MediaOutcomeConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.MediaOutcomeConverter(mmm_wrapper=mmm.Mmm(mmm_pb.Mmm())) + + self.assertEmpty(list(conv())) + + def test_call_revenue_kpi(self): + conv = converters.MediaOutcomeConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_OUTCOME) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + 
year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_BASELINE_PSEUDO_CHANNEL_INDEX, + c.BASELINE, + 50.0, + 0.05, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "Channel 1", + 100.0, + 0.1, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "Channel 2", + 200.0, + 0.2, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "RF Channel 1", + 100.0, + 0.1, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "RF Channel 2", + 200.0, + 0.2, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_ALL_CHANNELS_PSEUDO_CHANNEL_INDEX, + c.ALL_CHANNELS, + 1000.0, + 1.0, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_OUTCOME_CHANNEL_INDEX_COLUMN, + dc.MEDIA_OUTCOME_CHANNEL_COLUMN, + dc.MEDIA_OUTCOME_INCREMENTAL_OUTCOME_COLUMN, + dc.MEDIA_OUTCOME_CONTRIBUTION_SHARE_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + def test_call_nonrevenue_kpi(self): + conv = converters.MediaOutcomeConverter( + mmm_wrapper=mmm.Mmm(mmm_proto=_NONREVENUE_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_OUTCOME) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_BASELINE_PSEUDO_CHANNEL_INDEX, + c.BASELINE, + 40.0, + 0.04, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "Channel 1", + 120.0, + 0.12, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "Channel 2", + 220.0, + 0.22, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "RF Channel 1", + 120.0, + 0.12, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_CHANNEL_INDEX, + "RF Channel 2", + 220.0, + 0.22, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + dc.MEDIA_OUTCOME_ALL_CHANNELS_PSEUDO_CHANNEL_INDEX, + c.ALL_CHANNELS, + 1000.0, + 1.0, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_OUTCOME_CHANNEL_INDEX_COLUMN, + dc.MEDIA_OUTCOME_CHANNEL_COLUMN, + dc.MEDIA_OUTCOME_INCREMENTAL_OUTCOME_COLUMN, + dc.MEDIA_OUTCOME_CONTRIBUTION_SHARE_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + 
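+
+# NOTE: The expected share values in the tests below follow the converter's
+# normalization: each channel's outcome share is divided by the sum across
+# real channels, e.g. 0.1 / (0.1 + 0.2 + 0.1 + 0.2) = 1/6 for "Channel 1",
+# so that the yielded shares sum to 100%, as the Looker Studio charts expect.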
+ +class MediaSpendConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.MediaSpendConverter(mmm_wrapper=mmm.Mmm(mmm_pb.Mmm())) + + self.assertEmpty(list(conv())) + + def test_call_revenue_kpi(self): + conv = converters.MediaSpendConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_SPEND) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + "Channel 1", + 0.5, + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 1", + 0.1 / (0.1 + 0.2 + 0.1 + 0.2), + dc.MEDIA_SPEND_LABEL_REVENUE_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + (1.0 / 6.0), + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + 0.2 / (0.1 + 0.2 + 0.1 + 0.2), + dc.MEDIA_SPEND_LABEL_REVENUE_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 0.2, + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 0.1 / (0.1 + 0.2 + 0.1 + 0.2), + dc.MEDIA_SPEND_LABEL_REVENUE_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + (2.0 / 15.0), + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + 0.2 / (0.1 + 0.2 + 0.1 + 0.2), + dc.MEDIA_SPEND_LABEL_REVENUE_SHARE, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_SPEND_CHANNEL_COLUMN, + dc.MEDIA_SPEND_SHARE_VALUE_COLUMN, + dc.MEDIA_SPEND_LABEL_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + def test_call_nonrevenue_kpi(self): + conv = converters.MediaSpendConverter( + mmm_wrapper=mmm.Mmm(mmm_proto=_NONREVENUE_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_SPEND) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + "Channel 1", + 0.5, + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 1", + 0.12 / (0.12 + 0.22 + 0.12 + 0.22), + dc.MEDIA_SPEND_LABEL_KPI_SHARE, # NOT "revenue"! 
+ tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + (1.0 / 6.0), + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + 0.22 / (0.12 + 0.22 + 0.12 + 0.22), + dc.MEDIA_SPEND_LABEL_KPI_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 0.2, + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 0.12 / (0.12 + 0.22 + 0.12 + 0.22), + dc.MEDIA_SPEND_LABEL_KPI_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + (2.0 / 15.0), + dc.MEDIA_SPEND_LABEL_SPEND_SHARE, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + 0.22 / (0.12 + 0.22 + 0.12 + 0.22), + dc.MEDIA_SPEND_LABEL_KPI_SHARE, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_SPEND_CHANNEL_COLUMN, + dc.MEDIA_SPEND_SHARE_VALUE_COLUMN, + dc.MEDIA_SPEND_LABEL_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + +class MediaRoiConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.MediaRoiConverter(mmm_wrapper=mmm.Mmm(mmm_pb.Mmm())) + + self.assertEmpty(list(conv())) + + def test_call_revenue_kpi(self): + conv = converters.MediaRoiConverter(mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO)) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_ROI) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + "Channel 1", + 75_000.0, + 2.2, + 1.0, + 0.9, + 1.1, + 10.0, + True, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + 25_000.0, + 4.4, + 2.0, + 1.8, + 2.2, + 20.0, + True, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 30_000.0, + 2.2, + 1.0, + 0.9, + 1.1, + 10.0, + True, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + 20_000.0, + 4.4, + 2.0, + 1.8, + 2.2, + 20.0, + True, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_ROI_CHANNEL_COLUMN, + dc.MEDIA_ROI_SPEND_COLUMN, + dc.MEDIA_ROI_EFFECTIVENESS_COLUMN, + dc.MEDIA_ROI_ROI_COLUMN, + dc.MEDIA_ROI_ROI_CI_LOW_COLUMN, + dc.MEDIA_ROI_ROI_CI_HIGH_COLUMN, + dc.MEDIA_ROI_MARGINAL_ROI_COLUMN, + dc.MEDIA_ROI_IS_REVENUE_KPI_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + def test_call_nonrevenue_kpi(self): + conv = converters.MediaRoiConverter( + mmm_wrapper=mmm.Mmm(mmm_proto=_NONREVENUE_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.MEDIA_ROI) + + expected_data_frame_data = [] + for date_interval in [td.ALL_DATE_INTERVAL] + td.DATE_INTERVALS: + date_start = datetime.datetime( + 
year=date_interval.start_date.year, + month=date_interval.start_date.month, + day=date_interval.start_date.day, + ).strftime(c.DATE_FORMAT) + date_end = datetime.datetime( + year=date_interval.end_date.year, + month=date_interval.end_date.month, + day=date_interval.end_date.day, + ).strftime(c.DATE_FORMAT) + tag = date_interval.tag + + expected_data_frame_data.append([ + "Channel 1", + 75_000.0, + 6.6, + 12.0, + 10.8, + 13.2, + 120.0, + False, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "Channel 2", + 25_000.0, + 12.1, + 22.0, + 19.8, + 24.2, + 220.0, + False, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 1", + 30_000.0, + 6.6, + 12.0, + 10.8, + 13.2, + 120.0, + False, + tag, + date_start, + date_end, + ]) + expected_data_frame_data.append([ + "RF Channel 2", + 20_000.0, + 12.1, + 22.0, + 19.8, + 24.2, + 220.0, + False, + tag, + date_start, + date_end, + ]) + + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + expected_data_frame_data, + columns=[ + dc.MEDIA_ROI_CHANNEL_COLUMN, + dc.MEDIA_ROI_SPEND_COLUMN, + dc.MEDIA_ROI_EFFECTIVENESS_COLUMN, + dc.MEDIA_ROI_ROI_COLUMN, + dc.MEDIA_ROI_ROI_CI_LOW_COLUMN, + dc.MEDIA_ROI_ROI_CI_HIGH_COLUMN, + dc.MEDIA_ROI_MARGINAL_ROI_COLUMN, + dc.MEDIA_ROI_IS_REVENUE_KPI_COLUMN, + dc.ANALYSIS_PERIOD_COLUMN, + dc.ANALYSIS_DATE_START_COLUMN, + dc.ANALYSIS_DATE_END_COLUMN, + ], + ), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/scenarioplanner/converters/dataframe/rf_opt_converters.py b/scenarioplanner/converters/dataframe/rf_opt_converters.py new file mode 100644 index 000000000..ded87554f --- /dev/null +++ b/scenarioplanner/converters/dataframe/rf_opt_converters.py @@ -0,0 +1,314 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reach and frequency optimization output converters. + +This module defines various classes that convert +`ReachFrequencyOptimizationResult`s into flat dataframes. +""" + +import abc +from collections.abc import Iterator, Sequence + +from meridian import constants as c +from scenarioplanner.converters import mmm +from scenarioplanner.converters.dataframe import common +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import converter +import pandas as pd + + +__all__ = [ + "NamedRfOptimizationGridConverter", + "RfOptimizationSpecsConverter", + "RfOptimizationResultsConverter", +] + + +class _RfOptimizationConverter(converter.Converter, abc.ABC): + """An abstract class for dealing with `ReachFrequencyOptimizationResult`s.""" + + def __call__(self) -> Iterator[tuple[str, pd.DataFrame]]: + results = self._mmm.reach_frequency_optimization_results + if not results: + return + + # Validate group IDs. + group_ids = [result.group_id for result in results if result.group_id] + if len(set(group_ids)) != len(group_ids): + raise ValueError( + "Specified group_id must be unique or unset among the given group of" + " results." 
+      )
+
+    yield from self._handle_rf_optimization_results(
+        self._mmm.reach_frequency_optimization_results
+    )
+
+  def _handle_rf_optimization_results(
+      self, results: Sequence[mmm.ReachFrequencyOptimizationResult]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    raise NotImplementedError()
+
+
+class NamedRfOptimizationGridConverter(_RfOptimizationConverter):
+  """Outputs named tables for Reach & Frequency optimization grids.
+
+  When called, this converter yields a data frame with the columns:
+
+  * "Group ID"
+    A UUID generated for each named frequency outcome grid.
+  * "Channel"
+  * "Frequency"
+  * "ROI"
+
+  One data frame is yielded for each named R&F optimization result in the MMM
+  output proto.
+  """
+
+  def _handle_rf_optimization_results(
+      self, results: Sequence[mmm.ReachFrequencyOptimizationResult]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    for rf_opt_result in results:
+      # There should be one unique ID for each result.
+      group_id = str(rf_opt_result.group_id) if rf_opt_result.group_id else ""
+      grid = rf_opt_result.frequency_outcome_grid
+
+      # Each grid yields its own data frame table.
+      rf_optimization_grid_data = []
+      for channel, cells in grid.channel_frequency_grids.items():
+        for frequency, outcome in cells:
+          rf_optimization_grid_data.append([
+              group_id,
+              channel,
+              frequency,
+              outcome,
+          ])
+
+      yield (
+          common.create_grid_sheet_name(
+              dc.RF_OPTIMIZATION_GRID_NAME_PREFIX, grid.name
+          ),
+          pd.DataFrame(
+              rf_optimization_grid_data,
+              columns=[
+                  dc.OPTIMIZATION_GROUP_ID_COLUMN,
+                  dc.OPTIMIZATION_CHANNEL_COLUMN,
+                  dc.RF_OPTIMIZATION_GRID_FREQ_COLUMN,
+                  dc.RF_OPTIMIZATION_GRID_ROI_OUTCOME_COLUMN,
+              ],
+          ),
+      )
+
+
+class RfOptimizationSpecsConverter(_RfOptimizationConverter):
+  """Outputs a table of R&F optimization specs.
+
+  When called, this converter yields a data frame with the columns:
+
+  * "Group ID"
+    A UUID generated for an R&F frequency outcome grid present in the output.
+  * "Date Interval Start"
+  * "Date Interval End"
+  * "Objective"
+  * "Initial Channel Spend"
+  * "Channel"
+  * "Channel Frequency Min"
+  * "Channel Frequency Max"
+  """
+
+  def _handle_rf_optimization_results(
+      self, results: Sequence[mmm.ReachFrequencyOptimizationResult]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    spec_data = []
+    for rf_opt_result in results:
+      # There should be one unique ID for each result.
+      group_id = str(rf_opt_result.group_id) if rf_opt_result.group_id else ""
+      spec = rf_opt_result.spec
+
+      objective = common.map_target_metric_str(spec.objective)
+      # These are the start and end dates for the requested R&F optimization
+      # in this spec.
+      date_interval_start, date_interval_end = (
+          d.strftime(c.DATE_FORMAT) for d in spec.date_interval.date_interval
+      )
+      rf_date_interval = (date_interval_start, date_interval_end)
+
+      # That is, the historical spend from marketing data in the model kernel.
+      initial_channel_spends = self._mmm.marketing_data.rf_channel_spends(
+          rf_date_interval
+      )
+
+      # When the constraint of a channel is not specified, that channel will
+      # have a default frequency constraint of `[1.0, max_freq]`.
+      #
+      # NOTE: We assume that the processor has already done this max_freq
+      # computation, so we can assert here that channel constraints are
+      # always fully specified for R&F channels.
+      channel_constraints = spec.channel_constraints
+      if not channel_constraints:
+        raise ValueError(
+            "R&F optimization spec must have channel constraints specified."
+        )
+      if set([
+          channel_constraint.channel_name
+          for channel_constraint in channel_constraints
+      ]) != set(self._mmm.marketing_data.rf_channels):
+        raise ValueError(
+            "R&F optimization spec must have channel constraints specified for"
+            " all R&F channels."
+        )
+
+      for channel_constraint in channel_constraints:
+        min_freq = channel_constraint.frequency_constraint.min_frequency or 1.0
+        max_freq = channel_constraint.frequency_constraint.max_frequency
+        if not max_freq:
+          raise ValueError(
+              "Channel constraint in R&F optimization spec must have max"
+              " frequency specified. Missing for channel:"
+              f" {channel_constraint.channel_name}"
+          )
+        spec_data.append([
+            group_id,
+            date_interval_start,
+            date_interval_end,
+            objective,
+            initial_channel_spends.get(channel_constraint.channel_name, 0.0),
+            channel_constraint.channel_name,
+            min_freq,
+            max_freq,
+        ])
+
+    yield (
+        dc.RF_OPTIMIZATION_SPECS,
+        pd.DataFrame(
+            spec_data,
+            columns=[
+                dc.OPTIMIZATION_GROUP_ID_COLUMN,
+                dc.OPTIMIZATION_SPEC_DATE_INTERVAL_START_COLUMN,
+                dc.OPTIMIZATION_SPEC_DATE_INTERVAL_END_COLUMN,
+                dc.OPTIMIZATION_SPEC_OBJECTIVE_COLUMN,
+                dc.OPTIMIZATION_SPEC_INITIAL_CHANNEL_SPEND_COLUMN,
+                dc.OPTIMIZATION_CHANNEL_COLUMN,
+                dc.RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MIN_COLUMN,
+                dc.RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MAX_COLUMN,
+            ],
+        ),
+    )
+
+
+class RfOptimizationResultsConverter(_RfOptimizationConverter):
+  """Outputs a table of R&F optimization results.
+
+  When called, this converter yields a data frame with the columns:
+
+  * "Group ID"
+    A UUID generated for an R&F optimization result present in the output.
+  * "Channel"
+  * "Is Revenue KPI"
+    Whether the KPI is revenue or not.
+  * "Initial Spend"
+  * "Optimal Avg Frequency"
+  * "Optimal Impression Effectiveness"
+  * "Optimal ROI"
+  * "Optimal mROI"
+  * "Optimal CPC"
+  """
+
+  def _handle_rf_optimization_results(
+      self, results: Sequence[mmm.ReachFrequencyOptimizationResult]
+  ) -> Iterator[tuple[str, pd.DataFrame]]:
+    data = []
+    for rf_opt_result in results:
+      group_id = str(rf_opt_result.group_id) if rf_opt_result.group_id else ""
+      marketing_analysis = rf_opt_result.optimized_marketing_analysis
+
+      spec = rf_opt_result.spec
+      # These are the start and end dates for the requested R&F optimization
+      # in this spec.
+      date_interval_start, date_interval_end = (
+          d.strftime(c.DATE_FORMAT) for d in spec.date_interval.date_interval
+      )
+      rf_date_interval = (date_interval_start, date_interval_end)
+      # That is, the historical spend from marketing data in the model kernel.
+      initial_budget = self._mmm.marketing_data.rf_channel_spends(
+          rf_date_interval
+      )
+
+      media_channel_analyses = marketing_analysis.channel_mapped_media_analyses
+      for channel, media_analysis in media_channel_analyses.items():
+        # Skip "All Channels" pseudo-channel.
+        if channel == c.ALL_CHANNELS:
+          continue
+        # Skip non-R&F channels.
+ if channel not in self._mmm.marketing_data.rf_channels: + continue + + initial_spend = initial_budget[channel] + optimal_avg_freq = rf_opt_result.channel_mapped_optimized_frequencies[ + channel + ] + + revenue_outcome = media_analysis.maybe_revenue_outcome + nonrevenue_outcome = media_analysis.maybe_non_revenue_outcome + + # pylint: disable=cell-var-from-loop + def _append_outcome_data( + outcome: mmm.Outcome | None, + is_revenue_kpi: bool, + ) -> None: + if outcome is None: + return + effectiveness = outcome.effectiveness_pb.value.value + roi = outcome.roi_pb.value + mroi = outcome.marginal_roi_pb.value + cpc = outcome.cost_per_contribution_pb.value + data.append([ + group_id, + channel, + is_revenue_kpi, + initial_spend, + optimal_avg_freq, + effectiveness, + roi, + mroi, + cpc, + ]) + + _append_outcome_data(revenue_outcome, True) + _append_outcome_data(nonrevenue_outcome, False) + # pylint: enable=cell-var-from-loop + + yield ( + dc.RF_OPTIMIZATION_RESULTS, + pd.DataFrame( + data, + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN, + dc.RF_OPTIMIZATION_RESULT_INITIAL_SPEND_COLUMN, + dc.RF_OPTIMIZATION_RESULT_AVG_FREQ_COLUMN, + dc.OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN, + dc.OPTIMIZATION_RESULT_ROI_COLUMN, + dc.OPTIMIZATION_RESULT_MROI_COLUMN, + dc.OPTIMIZATION_RESULT_CPC_COLUMN, + ], + ), + ) + + +CONVERTERS = [ + NamedRfOptimizationGridConverter, + RfOptimizationSpecsConverter, + RfOptimizationResultsConverter, +] diff --git a/scenarioplanner/converters/dataframe/rf_opt_converters_test.py b/scenarioplanner/converters/dataframe/rf_opt_converters_test.py new file mode 100644 index 000000000..58d3fb178 --- /dev/null +++ b/scenarioplanner/converters/dataframe/rf_opt_converters_test.py @@ -0,0 +1,407 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
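+
+"""Tests for the R&F optimization dataframe converters."""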
+ +from absl.testing import absltest +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.common import kpi_type_pb2 as kpi_type_pb +from mmm.v1.marketing.optimization import marketing_optimization_pb2 as optimization_pb +from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from scenarioplanner.converters import mmm +from scenarioplanner.converters import test_data as td +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.converters.dataframe import rf_opt_converters as converters +import pandas as pd + + +mock = absltest.mock + + +_DEFAULT_MMM_PROTO = mmm_pb.Mmm( + mmm_kernel=kernel_pb.MmmKernel( + marketing_data=td.MARKETING_DATA, + ), + marketing_optimization=optimization_pb.MarketingOptimization( + reach_frequency_optimization=rf_pb.ReachFrequencyOptimization( + results=[ + td.RF_OPTIMIZATION_RESULT_FOO, + ] + ), + ), +) + +_GID = td.RF_OPTIMIZATION_RESULT_FOO.group_id + + +class NamedRfOptimizationGridConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.NamedRfOptimizationGridConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.NamedRfOptimizationGridConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + dataframes = list(conv()) + + self.assertLen(dataframes, 1) + foo_grid_name, foo_grid_df = dataframes[0] + + expected_foo_grid_name = "_".join( + [dc.RF_OPTIMIZATION_GRID_NAME_PREFIX, "frequency_outcome_grid_foo"] + ) + + self.assertEqual(foo_grid_name, expected_foo_grid_name) + + pd.testing.assert_frame_equal( + foo_grid_df, + pd.DataFrame( + [ + [ + _GID, + "RF Channel 1", + 1.0, + 100.0, + ], + [ + _GID, + "RF Channel 1", + 2.0, + 200.0, + ], + [ + _GID, + "RF Channel 2", + 1.0, + 100.0, + ], + [ + _GID, + "RF Channel 2", + 2.0, + 200.0, + ], + ], + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.RF_OPTIMIZATION_GRID_FREQ_COLUMN, + dc.RF_OPTIMIZATION_GRID_ROI_OUTCOME_COLUMN, + ], + ), + ) + + +class RfOptimizationSpecsConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.RfOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call(self): + conv = converters.RfOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.RF_OPTIMIZATION_SPECS) + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + [ + [ + _GID, + "2024-01-01", + "2024-01-15", + dc.OPTIMIZATION_SPEC_TARGET_METRIC_KPI, + 440.0, + "RF Channel 1", + 1.0, + 5.0, + ], + [ + _GID, + "2024-01-01", + "2024-01-15", + dc.OPTIMIZATION_SPEC_TARGET_METRIC_KPI, + 440.0, + "RF Channel 2", + 1.3, + 6.6, + ], + ], + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_SPEC_DATE_INTERVAL_START_COLUMN, + dc.OPTIMIZATION_SPEC_DATE_INTERVAL_END_COLUMN, + dc.OPTIMIZATION_SPEC_OBJECTIVE_COLUMN, + dc.OPTIMIZATION_SPEC_INITIAL_CHANNEL_SPEND_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MIN_COLUMN, + dc.RF_OPTIMIZATION_SPEC_CHANNEL_FREQUENCY_MAX_COLUMN, + ], + ), + ) + + def test_call_no_rf_channel_constraints(self): + mmm_proto = mmm_pb.Mmm() + mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO) + mmm_proto.marketing_optimization.reach_frequency_optimization.results[ + 0 + ].spec.ClearField("rf_channel_constraints") + + conv = 
converters.RfOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(mmm_proto) + ) + + with self.assertRaisesRegex( + ValueError, + "R&F optimization spec must have channel constraints specified.", + ): + next(conv()) + + def test_call_missing_an_rf_channel_constraint(self): + mmm_proto = mmm_pb.Mmm() + mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO) + mmm_proto.marketing_optimization.reach_frequency_optimization.results[ + 0 + ].spec.rf_channel_constraints.pop() + + conv = converters.RfOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(mmm_proto) + ) + + with self.assertRaisesRegex( + ValueError, + "R&F optimization spec must have channel constraints specified for all" + " R&F channels.", + ): + next(conv()) + + def test_call_missing_max_frequency_constraint(self): + mmm_proto = mmm_pb.Mmm() + mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO) + mmm_proto.marketing_optimization.reach_frequency_optimization.results[ + 0 + ].spec.rf_channel_constraints[1].frequency_constraint.ClearField( + "max_frequency" + ) + + conv = converters.RfOptimizationSpecsConverter( + mmm_wrapper=mmm.Mmm(mmm_proto) + ) + + with self.assertRaisesRegex( + ValueError, + "Channel constraint in R&F optimization spec must have max frequency" + " specified. Missing for channel: RF Channel 2", + ): + next(conv()) + + +class RfOptimizationResultsConverterTest(absltest.TestCase): + + def test_call_no_results(self): + conv = converters.RfOptimizationResultsConverter( + mmm_wrapper=mmm.Mmm(mmm_pb.Mmm()) + ) + + self.assertEmpty(list(conv())) + + def test_call_duplicate_group_id(self): + mmm_proto = mmm_pb.Mmm() + mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO) + mmm_proto.marketing_optimization.reach_frequency_optimization.results.append( + td.RF_OPTIMIZATION_RESULT_FOO + ) + + with self.assertRaisesRegex( + ValueError, "Specified group_id must be unique" + ): + conv = converters.RfOptimizationResultsConverter( + mmm_wrapper=mmm.Mmm(mmm_proto) + ) + next(conv()) + + def test_call(self): + conv = converters.RfOptimizationResultsConverter( + mmm_wrapper=mmm.Mmm(_DEFAULT_MMM_PROTO) + ) + + name, output_df = next(conv()) + + self.assertEqual(name, dc.RF_OPTIMIZATION_RESULTS) + pd.testing.assert_frame_equal( + output_df, + pd.DataFrame( + [ + [ + _GID, + "RF Channel 1", + True, + 440.0, + 3.3, + 2.2, + 1.0, + 10.0, + 5.0, + ], + [ + _GID, + "RF Channel 1", + False, + 440.0, + 3.3, + 5.5, + 10.0, + 100.0, + 100.0, + ], + [ + _GID, + "RF Channel 2", + True, + 440.0, + 5.6, + 4.4, + 2.0, + 20.0, + 10.0, + ], + [ + _GID, + "RF Channel 2", + False, + 440.0, + 5.6, + 11.0, + 20.0, + 200.0, + 200.0, + ], + ], + columns=[ + dc.OPTIMIZATION_GROUP_ID_COLUMN, + dc.OPTIMIZATION_CHANNEL_COLUMN, + dc.OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN, + dc.RF_OPTIMIZATION_RESULT_INITIAL_SPEND_COLUMN, + dc.RF_OPTIMIZATION_RESULT_AVG_FREQ_COLUMN, + dc.OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN, + dc.OPTIMIZATION_RESULT_ROI_COLUMN, + dc.OPTIMIZATION_RESULT_MROI_COLUMN, + dc.OPTIMIZATION_RESULT_CPC_COLUMN, + ], + ), + ) + + def test_call_no_revenue_baseline_outcome(self): + mmm_proto = mmm_pb.Mmm() + mmm_proto.CopyFrom(_DEFAULT_MMM_PROTO) + # Remove revenue-type outcomes from baseline analyses. 
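+    # Only per-channel media analyses feed the results table, so the expected
+    # output below should match the unmodified case.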
+    rf_opt = mmm_proto.marketing_optimization.reach_frequency_optimization
+    for rf_result in rf_opt.results:
+      opt_marketing_analysis = rf_result.optimized_marketing_analysis
+      for non_media_analysis in opt_marketing_analysis.non_media_analyses:
+        if non_media_analysis.non_media_name != "baseline":
+          continue
+        revenue_outcome_index = None
+        for i, outcome in enumerate(non_media_analysis.non_media_outcomes):
+          if outcome.kpi_type == kpi_type_pb.REVENUE:
+            revenue_outcome_index = i
+            break
+        if revenue_outcome_index is not None:
+          non_media_analysis.non_media_outcomes.pop(revenue_outcome_index)
+
+    conv = converters.RfOptimizationResultsConverter(
+        mmm_wrapper=mmm.Mmm(mmm_proto)
+    )
+
+    name, output_df = next(conv())
+
+    self.assertEqual(name, dc.RF_OPTIMIZATION_RESULTS)
+    pd.testing.assert_frame_equal(
+        output_df,
+        pd.DataFrame(
+            [
+                [
+                    _GID,
+                    "RF Channel 1",
+                    True,
+                    440.0,
+                    3.3,
+                    2.2,
+                    1.0,
+                    10.0,
+                    5.0,
+                ],
+                [
+                    _GID,
+                    "RF Channel 1",
+                    False,
+                    440.0,
+                    3.3,
+                    5.5,
+                    10.0,
+                    100.0,
+                    100.0,
+                ],
+                [
+                    _GID,
+                    "RF Channel 2",
+                    True,
+                    440.0,
+                    5.6,
+                    4.4,
+                    2.0,
+                    20.0,
+                    10.0,
+                ],
+                [
+                    _GID,
+                    "RF Channel 2",
+                    False,
+                    440.0,
+                    5.6,
+                    11.0,
+                    20.0,
+                    200.0,
+                    200.0,
+                ],
+            ],
+            columns=[
+                dc.OPTIMIZATION_GROUP_ID_COLUMN,
+                dc.OPTIMIZATION_CHANNEL_COLUMN,
+                dc.OPTIMIZATION_RESULT_IS_REVENUE_KPI_COLUMN,
+                dc.RF_OPTIMIZATION_RESULT_INITIAL_SPEND_COLUMN,
+                dc.RF_OPTIMIZATION_RESULT_AVG_FREQ_COLUMN,
+                dc.OPTIMIZATION_RESULT_EFFECTIVENESS_COLUMN,
+                dc.OPTIMIZATION_RESULT_ROI_COLUMN,
+                dc.OPTIMIZATION_RESULT_MROI_COLUMN,
+                dc.OPTIMIZATION_RESULT_CPC_COLUMN,
+            ],
+        ),
+    )
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/scenarioplanner/converters/mmm.py b/scenarioplanner/converters/mmm.py
new file mode 100644
index 000000000..c9b884b2f
--- /dev/null
+++ b/scenarioplanner/converters/mmm.py
@@ -0,0 +1,743 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides wrappers for the `Mmm` proto.
+
+This module defines a set of dataclasses that act as high-level wrappers around
+the `Mmm` protocol buffer and its nested messages. The primary goal is to offer
+a more intuitive API for accessing and manipulating MMM data, abstracting away
+the verbosity of the raw protobuf structures.
+
+The main entry point is the `Mmm` class, which wraps the top-level `mmm_pb2.Mmm`
+proto. From an instance of this class, you can navigate through the model's
+different components, such as marketing data, model fit results, and various
+analyses, using simple properties and methods.
+
+Typical Usage:
+
+```python
+from mmm.v1 import mmm_pb2
+from scenarioplanner.converters import mmm
+
+# Assume `mmm_proto` is a populated instance of the Mmm proto
+mmm_proto = mmm_pb2.Mmm()
+# ...
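+# (in practice the proto is typically parsed from serialized model output,
+# e.g. via `mmm_proto.ParseFromString(...)`; shown here only as illustration)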
+ +# Create the main wrapper instance +mmm_wrapper = mmm.Mmm(mmm_proto) + +# Access marketing data and calculate total spends for a given period +marketing_data = mmm_wrapper.marketing_data +total_spends = marketing_data.all_channel_spends( + date_interval=('2025-01-01', '2025-03-31') +) + +# Access budget optimization results +for budget_result in mmm_wrapper.budget_optimization_results: + print(f"Name: {budget_result.name}, Max: {budget_result.spec.max_budget}") +``` +""" + +import abc +import dataclasses +import datetime +import functools +from typing import TypeAlias + +from meridian import constants as c +from meridian.data import time_coordinates as tc +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.common import date_interval_pb2 as date_interval_pb +from mmm.v1.common import estimate_pb2 as estimate_pb +from mmm.v1.common import kpi_type_pb2 as kpi_type_pb +from mmm.v1.common import target_metric_pb2 as target_metric_pb +from mmm.v1.fit import model_fit_pb2 as fit_pb +from mmm.v1.marketing import marketing_data_pb2 as marketing_data_pb +from mmm.v1.marketing.analysis import marketing_analysis_pb2 as marketing_pb +from mmm.v1.marketing.analysis import media_analysis_pb2 as media_pb +from mmm.v1.marketing.analysis import non_media_analysis_pb2 as non_media_pb +from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb +from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb +from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb +from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb +from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb + +from google.type import date_pb2 as date_pb + + +_DateIntervalTuple: TypeAlias = tuple[datetime.date, datetime.date] + + +@dataclasses.dataclass(frozen=True) +class DateInterval: + """A dataclass wrapper around a tuple of `(start, end)` dates.""" + + date_interval: _DateIntervalTuple + + @property + def start(self) -> datetime.date: + return self.date_interval[0] + + @property + def end(self) -> datetime.date: + return self.date_interval[1] + + def __contains__(self, date: datetime.date) -> bool: + """Returns whether this date interval contains the given date.""" + return self.start <= date < self.end + + def __lt__(self, other: "DateInterval") -> bool: + return self.start < other.start + + +def _to_datetime_date( + date_proto: date_pb.Date, +) -> datetime.date: + """Converts a `Date` proto into a `datetime.date`.""" + return datetime.date( + year=date_proto.year, month=date_proto.month, day=date_proto.day + ) + + +def _to_date_interval_dc( + date_interval: date_interval_pb.DateInterval, +) -> DateInterval: + """Converts a `DateInterval` proto into `DateInterval` dataclass.""" + return DateInterval(( + _to_datetime_date(date_interval.start_date), + _to_datetime_date(date_interval.end_date), + )) + + +@dataclasses.dataclass(frozen=True) +class Outcome: + """A wrapper for `Outcome` proto with derived properties.""" + + outcome_proto: outcome_pb.Outcome + + @property + def is_revenue_kpi(self) -> bool: + return self.outcome_proto.kpi_type == kpi_type_pb.REVENUE + + @property + def is_nonrevenue_kpi(self) -> bool: + return self.outcome_proto.kpi_type == kpi_type_pb.NON_REVENUE + + @property + def contribution_pb(self) -> outcome_pb.Contribution: + return self.outcome_proto.contribution + + @property + def effectiveness_pb(self) -> outcome_pb.Effectiveness: + return self.outcome_proto.effectiveness + + @property + def roi_pb(self) -> 
estimate_pb.Estimate: + return self.outcome_proto.roi + + @property + def marginal_roi_pb(self) -> estimate_pb.Estimate: + return self.outcome_proto.marginal_roi + + @property + def cost_per_contribution_pb(self) -> estimate_pb.Estimate: + return self.outcome_proto.cost_per_contribution + + +class _OutcomeMixin(abc.ABC): + """Mixin for (non-)media analysis with typed KPI outcome property getters. + + A `MediaAnalysis` or `NonMediaAnalysis` proto is configured with multiple + polymorphic `Outcome`s. In Meridian processors, both types (revenue and + non-revenue) may be present in the analysis container. However, for each type + there should be at most one `Outcome` value. + + This mixin provides both `MediaAnalysis` and `NonMediaAnalysis` dataclasses + with property getters to retrieve typed `Outcome` values. + """ + + @property + @abc.abstractmethod + def _outcome_pbs(self) -> list[outcome_pb.Outcome]: + """Returns a list of `Outcome` protos.""" + raise NotImplementedError() + + @functools.cached_property + def maybe_revenue_outcome(self) -> Outcome | None: + """Returns the revenue-type `Outcome`, or None if it does not exist.""" + for outcome_proto in self._outcome_pbs: + outcome = Outcome(outcome_proto) + if outcome.is_revenue_kpi: + return outcome + return None + + @property + def revenue_outcome(self) -> Outcome: + """Returns the revenue-type `Outcome`, or raises an error if it does not exist.""" + outcome = self.maybe_revenue_outcome + if outcome is None: + raise ValueError( + "No revenue-type `Outcome` found in an expected analysis proto." + ) + return outcome + + @functools.cached_property + def maybe_non_revenue_outcome(self) -> Outcome | None: + """Returns the nonrevenue-type `Outcome`, or None if it does not exist.""" + for outcome_proto in self._outcome_pbs: + outcome = Outcome(outcome_proto) + if outcome.is_nonrevenue_kpi: + return outcome + return None + + @property + def non_revenue_outcome(self) -> Outcome: + """Returns the nonrevenue-type `Outcome`, or raises an error if it does not exist.""" + outcome = self.maybe_non_revenue_outcome + if outcome is None: + raise ValueError( + "No nonrevenue-type `Outcome` found in an expected analysis proto." 
+ ) + return outcome + + +@dataclasses.dataclass(frozen=True) +class MediaAnalysis(_OutcomeMixin): + """A wrapper for `MediaAnalysis` proto with derived properties.""" + + analysis_proto: media_pb.MediaAnalysis + + @property + def channel_name(self) -> str: + return self.analysis_proto.channel_name + + @property + def spend_info_pb(self) -> media_pb.SpendInfo: + return self.analysis_proto.spend_info + + @property + def _outcome_pbs(self) -> list[outcome_pb.Outcome]: + return list(self.analysis_proto.media_outcomes) + + +@dataclasses.dataclass(frozen=True) +class NonMediaAnalysis(_OutcomeMixin): + """A wrapper for `NonMediaAnalysis` proto with derived properties.""" + + analysis_proto: non_media_pb.NonMediaAnalysis + + @property + def non_media_name(self) -> str: + return self.analysis_proto.non_media_name + + @property + def _outcome_pbs(self) -> list[outcome_pb.Outcome]: + return list(self.analysis_proto.non_media_outcomes) + + +@dataclasses.dataclass(frozen=True) +class ResponseCurve: + """A wrapper for `ResponseCurve` proto with derived properties.""" + + channel_name: str + response_curve_proto: response_curve_pb.ResponseCurve + + @property + def input_name(self) -> str: + return self.response_curve_proto.input_name + + @property + def response_points(self) -> list[tuple[float, float]]: + """Returns `(spend, incremental outcome)` tuples for this channel's curve.""" + return [ + (point.input_value, point.incremental_kpi) + for point in self.response_curve_proto.response_points + ] + + +@dataclasses.dataclass(frozen=True) +class MarketingAnalysis: + """A wrapper for `MarketingAnalysis` proto with derived properties.""" + + marketing_analysis_proto: marketing_pb.MarketingAnalysis + + @property + def tag(self) -> str: + return self.marketing_analysis_proto.date_interval.tag + + @functools.cached_property + def analysis_date_interval( + self, + ) -> DateInterval: + return _to_date_interval_dc(self.marketing_analysis_proto.date_interval) + + @property + def analysis_date_interval_str(self) -> tuple[str, str]: + """Returns a tuple of `(date_start, date_end)` as strings.""" + return ( + self.analysis_date_interval.start.strftime(c.DATE_FORMAT), + self.analysis_date_interval.end.strftime(c.DATE_FORMAT), + ) + + @functools.cached_property + def channel_mapped_media_analyses(self) -> dict[str, MediaAnalysis]: + """Returns media analyses mapped to their channel names.""" + return { + analysis.channel_name: MediaAnalysis(analysis) + for analysis in self.marketing_analysis_proto.media_analyses + } + + @functools.cached_property + def channel_mapped_non_media_analyses(self) -> dict[str, NonMediaAnalysis]: + """Returns non-media analyses mapped to their non-media names.""" + return { + analysis.non_media_name: NonMediaAnalysis(analysis) + for analysis in self.marketing_analysis_proto.non_media_analyses + } + + @functools.cached_property + def baseline_analysis(self) -> NonMediaAnalysis: + """Returns a "baseline" non media analysis among the given values. + + Raises: + ValueError: if there is no "baseline" analysis + """ + for non_media_analysis in self.marketing_analysis_proto.non_media_analyses: + if non_media_analysis.non_media_name == c.BASELINE: + return NonMediaAnalysis(non_media_analysis) + else: + raise ValueError( + f"No '{c.BASELINE}' found in the set of `NonMediaAnalysis` for this" + " `MarketingAnalysis`." 
+      )
+
+  @functools.cached_property
+  def response_curves(self) -> list[ResponseCurve]:
+    """Returns a list of `ResponseCurve`s."""
+    return [
+        ResponseCurve(m_analysis.channel_name, m_analysis.response_curve)
+        for m_analysis in self.marketing_analysis_proto.media_analyses
+    ]
+
+
+@dataclasses.dataclass(frozen=True)
+class IncrementalOutcomeGrid:
+  """A wrapper for `IncrementalOutcomeGrid` proto with derived properties."""
+
+  incremental_outcome_grid_proto: budget_pb.IncrementalOutcomeGrid
+
+  @property
+  def name(self) -> str:
+    return self.incremental_outcome_grid_proto.name
+
+  @property
+  def channel_spend_grids(self) -> dict[str, list[tuple[float, float]]]:
+    """Returns channels mapped to (spend, incremental outcome) tuples."""
+    grid = {}
+    for channel_cells in self.incremental_outcome_grid_proto.channel_cells:
+      grid[channel_cells.channel_name] = [
+          (cell.spend, cell.incremental_outcome.value)
+          for cell in channel_cells.cells
+      ]
+    return grid
+
+
+class _SpecMixin(abc.ABC):
+  """Mixin for both budget and R&F optimization specs."""
+
+  @property
+  @abc.abstractmethod
+  def _date_interval_proto(self) -> date_interval_pb.DateInterval:
+    """Returns the date interval proto."""
+    raise NotImplementedError()
+
+  @functools.cached_property
+  def date_interval(self) -> DateInterval:
+    """Returns the spec's date interval."""
+    return _to_date_interval_dc(self._date_interval_proto)
+
+
+@dataclasses.dataclass(frozen=True)
+class BudgetOptimizationSpec(_SpecMixin):
+  """A wrapper for `BudgetOptimizationSpec` proto with derived properties."""
+
+  budget_optimization_spec_proto: budget_pb.BudgetOptimizationSpec
+
+  @property
+  def _date_interval_proto(self) -> date_interval_pb.DateInterval:
+    return self.budget_optimization_spec_proto.date_interval
+
+  @property
+  def date_interval_tag(self) -> str:
+    return self._date_interval_proto.tag
+
+  @property
+  def objective(self) -> target_metric_pb.TargetMetric:
+    return self.budget_optimization_spec_proto.objective
+
+  @property
+  def is_fixed_scenario(self) -> bool:
+    return (
+        self.budget_optimization_spec_proto.WhichOneof("scenario")
+        == "fixed_budget_scenario"
+    )
+
+  @property
+  def max_budget(self) -> float:
+    """Returns the maximum budget for this spec.
+
+    Max budget is the total budget for a fixed scenario spec, or the max budget
+    upper bound for a flexible scenario spec.
+    """
+    if self.is_fixed_scenario:
+      return (
+          self.budget_optimization_spec_proto.fixed_budget_scenario.total_budget
+      )
+    else:
+      return (
+          self.budget_optimization_spec_proto.flexible_budget_scenario.total_budget_constraint.max_budget
+      )
+
+  @functools.cached_property
+  def channel_constraints(self) -> list[budget_pb.ChannelConstraint]:
+    """Returns a list of `ChannelConstraint`s.
+
+    If the underlying spec proto has no channel constraints, the spec's overall
+    budget bounds implicitly apply to every channel. An empty list is returned
+    in this case; callers are expected to handle it.
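+
+    An illustrative caller sketch (not part of this API):
+
+      constraints = spec.channel_constraints
+      if not constraints:
+        # Fall back to the spec-level budget bounds for every channel.
+        ...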
+ """ + return list(self.budget_optimization_spec_proto.channel_constraints) + + +@dataclasses.dataclass(frozen=True) +class RfOptimizationSpec(_SpecMixin): + """A wrapper for `ReachFrequencyOptimizationSpec` proto with derived properties.""" + + rf_optimization_spec_proto: rf_pb.ReachFrequencyOptimizationSpec + + @property + def _date_interval_proto(self) -> date_interval_pb.DateInterval: + return self.rf_optimization_spec_proto.date_interval + + @property + def objective(self) -> target_metric_pb.TargetMetric: + return self.rf_optimization_spec_proto.objective + + @property + def total_budget_constraint(self) -> constraints_pb.BudgetConstraint: + return self.rf_optimization_spec_proto.total_budget_constraint + + @functools.cached_property + def channel_constraints(self) -> list[rf_pb.RfChannelConstraint]: + """Returns a list of `RfChannelConstraint`s.""" + return list(self.rf_optimization_spec_proto.rf_channel_constraints) + + +class _NamedResultMixin(abc.ABC): + """Mixin for named optimization results with assigned group ID.""" + + @property + @abc.abstractmethod + def group_id(self) -> str: + raise NotImplementedError() + + @property + @abc.abstractmethod + def name(self) -> str: + raise NotImplementedError() + + +@dataclasses.dataclass(frozen=True) +class BudgetOptimizationResult(_NamedResultMixin): + """A wrapper for `BudgetOptimizationResult` proto with derived properties.""" + + budget_optimization_result_proto: budget_pb.BudgetOptimizationResult + + @property + def name(self) -> str: + return self.budget_optimization_result_proto.name + + @property + def group_id(self) -> str: + return self.budget_optimization_result_proto.group_id + + @functools.cached_property + def spec(self) -> BudgetOptimizationSpec: + return BudgetOptimizationSpec(self.budget_optimization_result_proto.spec) + + @functools.cached_property + def optimized_marketing_analysis(self) -> MarketingAnalysis: + return MarketingAnalysis( + self.budget_optimization_result_proto.optimized_marketing_analysis + ) + + @functools.cached_property + def incremental_outcome_grid(self) -> IncrementalOutcomeGrid: + return IncrementalOutcomeGrid( + self.budget_optimization_result_proto.incremental_outcome_grid + ) + + @functools.cached_property + def response_curves(self) -> list[ResponseCurve]: + return MarketingAnalysis( + self.budget_optimization_result_proto.optimized_marketing_analysis + ).response_curves + + +@dataclasses.dataclass(frozen=True) +class FrequencyOutcomeGrid: + """A wrapper for `FrequencyOutcomeGrid` proto with derived properties.""" + + frequency_outcome_grid_proto: rf_pb.FrequencyOutcomeGrid + + @property + def name(self) -> str: + return self.frequency_outcome_grid_proto.name + + @property + def channel_frequency_grids(self) -> dict[str, list[tuple[float, float]]]: + """Returns channels mapped to (frequency, outcome) tuples.""" + grid = {} + for channel_cells in self.frequency_outcome_grid_proto.channel_cells: + grid[channel_cells.channel_name] = [ + (cell.reach_frequency.average_frequency, cell.outcome.value) + for cell in channel_cells.cells + ] + return grid + + +@dataclasses.dataclass(frozen=True) +class ReachFrequencyOptimizationResult(_NamedResultMixin): + """A wrapper for `ReachFrequencyOptimizationResult` proto with derived properties.""" + + rf_optimization_result_proto: rf_pb.ReachFrequencyOptimizationResult + + @property + def name(self) -> str: + return self.rf_optimization_result_proto.name + + @property + def group_id(self) -> str: + return self.rf_optimization_result_proto.group_id + + 
@functools.cached_property + def spec(self) -> RfOptimizationSpec: + return RfOptimizationSpec(self.rf_optimization_result_proto.spec) + + @functools.cached_property + def channel_mapped_optimized_frequencies(self) -> dict[str, float]: + """Returns optimized frequencies mapped to their channel names.""" + return { + optimized_channel_frequency.channel_name: ( + optimized_channel_frequency.optimal_average_frequency + ) + for optimized_channel_frequency in self.rf_optimization_result_proto.optimized_channel_frequencies + } + + @functools.cached_property + def optimized_marketing_analysis(self) -> MarketingAnalysis: + return MarketingAnalysis( + self.rf_optimization_result_proto.optimized_marketing_analysis + ) + + @functools.cached_property + def frequency_outcome_grid(self) -> FrequencyOutcomeGrid: + return FrequencyOutcomeGrid( + self.rf_optimization_result_proto.frequency_outcome_grid + ) + + +@dataclasses.dataclass(frozen=True) +class MarketingData: + """A wrapper for `MarketingData` proto with derived properties.""" + + marketing_data_proto: marketing_data_pb.MarketingData + + @property + def _marketing_data_points( + self, + ) -> list[marketing_data_pb.MarketingDataPoint]: + """Returns a list of `MarketingDataPoint`s.""" + return list(self.marketing_data_proto.marketing_data_points) + + @functools.cached_property + def media_channels(self) -> list[str]: + """Returns unique (non-R&F) media channel names in the marketing data.""" + channels = set() + for data_point in self._marketing_data_points: + for var in data_point.media_variables: + channels.add(var.channel_name) + return sorted(channels) # For deterministic order in iterating. + + @functools.cached_property + def rf_channels(self) -> list[str]: + """Returns unique R&F channel names in the marketing data.""" + channels = set() + for data_point in self._marketing_data_points: + for var in data_point.reach_frequency_variables: + channels.add(var.channel_name) + return sorted(channels) # For deterministic order in iterating. + + @functools.cached_property + def date_intervals(self) -> list[DateInterval]: + """Returns all date intervals in the marketing data.""" + date_intervals = set() + for data_point in self._marketing_data_points: + date_intervals.add(_to_date_interval_dc(data_point.date_interval)) + return sorted(date_intervals) + + def media_channel_spends( + self, date_interval: tc.DateInterval + ) -> dict[str, float]: + """Returns non-RF media channel names mapped to their total spend values, for the given date interval. + + All channel spends in time coordinates between `[start, end)` of the given + date interval are summed up. + + Args: + date_interval: the date interval to query for + + Returns: + A dict of channel names mapped to their total spend values, for the given + date interval. 
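+
+    Example (dates illustrative): a weekly data point whose interval starts
+    on 2024-01-08 is counted for the query `("2024-01-08", "2024-01-15")` but
+    not for `("2024-01-01", "2024-01-08")`.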
+ """ + date_interval = DateInterval(tc.normalize_date_interval(date_interval)) + channel_spends = {channel: 0.0 for channel in self.media_channels} + for data_point in self._marketing_data_points: + # The time coordinate for a marketing data point is the start date of its + # date interval field: test that it is contained within the given interval + data_point_date = _to_date_interval_dc(data_point.date_interval).start + if data_point_date not in date_interval: + continue + for var in data_point.media_variables: + channel_spends[var.channel_name] = ( + channel_spends[var.channel_name] + var.media_spend + ) + return channel_spends + + def rf_channel_spends( + self, date_interval: tc.DateInterval + ) -> dict[str, float]: + """Returns *Reach and Frequency* channel names mapped to their total spend values, for the given date interval. + + All channel spends in time coordinates between `[start, end)` of the given + date interval are summed up. + + Args: + date_interval: the date interval to query for + + Returns: + A dict of channel names mapped to their total spend values, for the given + date interval. + """ + date_interval = DateInterval(tc.normalize_date_interval(date_interval)) + channel_spends = {channel: 0.0 for channel in self.rf_channels} + for data_point in self._marketing_data_points: + # The time coordinate for a marketing data point is the start date of its + # date interval field: test that it is contained within the given interval + data_point_date = _to_date_interval_dc(data_point.date_interval).start + if data_point_date not in date_interval: + continue + for var in data_point.reach_frequency_variables: + channel_spends[var.channel_name] = ( + channel_spends[var.channel_name] + var.spend + ) + return channel_spends + + def all_channel_spends( + self, date_interval: tc.DateInterval + ) -> dict[str, float]: + """Returns *all* channel names mapped to their total spend values, for the given date interval. + + All channel spends in time coordinates between `[start, end)` of the given + date interval are summed up. + + Args: + date_interval: the date interval to query for + + Returns: + A dict of channel names mapped to their total spend values, for the given + date interval. 
+ """ + spends = self.rf_channel_spends(date_interval) + spends.update(self.media_channel_spends(date_interval)) + return spends + + +@dataclasses.dataclass(frozen=True) +class Mmm: + """A wrapper for `Mmm` proto with derived properties.""" + + mmm_proto: mmm_pb.Mmm + + @functools.cached_property + def marketing_data(self) -> MarketingData: + """Returns marketing data inside the MMM model kernel.""" + return MarketingData(self.mmm_proto.mmm_kernel.marketing_data) + + @property + def model_fit(self) -> fit_pb.ModelFit: + return self.mmm_proto.model_fit + + @functools.cached_property + def model_fit_results(self) -> dict[str, fit_pb.Result]: + """Returns each model fit `Result`, mapped to its dataset name.""" + return {result.name: result for result in self.model_fit.results} + + @functools.cached_property + def marketing_analyses(self) -> list[MarketingAnalysis]: + """Returns a list of `MarketingAnalysis` wrappers.""" + return [ + MarketingAnalysis(analysis) + for analysis in self.mmm_proto.marketing_analysis_list.marketing_analyses + ] + + @functools.cached_property + def tagged_marketing_analyses( + self, + ) -> dict[str, MarketingAnalysis]: + """Returns each marketing analysis, mapped to its tag name.""" + return {analysis.tag: analysis for analysis in self.marketing_analyses} + + @functools.cached_property + def budget_optimization_results( + self, + ) -> list[BudgetOptimizationResult]: + """Returns a list of `BudgetOptimizationResult` wrappers.""" + return [ + BudgetOptimizationResult(result) + for result in self.mmm_proto.marketing_optimization.budget_optimization.results + ] + + @functools.cached_property + def reach_frequency_optimization_results( + self, + ) -> list[ReachFrequencyOptimizationResult]: + """Returns a list of `ReachFrequencyOptimizationResult` wrappers.""" + return [ + ReachFrequencyOptimizationResult(result) + for result in self.mmm_proto.marketing_optimization.reach_frequency_optimization.results + ] diff --git a/scenarioplanner/converters/mmm_converter.py b/scenarioplanner/converters/mmm_converter.py new file mode 100644 index 000000000..093c4b4ff --- /dev/null +++ b/scenarioplanner/converters/mmm_converter.py @@ -0,0 +1,58 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Converts a fully specified trained model and analysis output. + +A fully specified trained model and its analyses are in its canonical proto +form. This module provides the API for its conversion to secondary forms +(e.g. flat CSV tables collated in a Sheets file) for immediate consumption +(e.g. as data sources for a Looker Studio dashboard). +""" + +import abc +from collections.abc import Mapping +from typing import Generic, TypeVar + +from mmm.v1 import mmm_pb2 as pb +from scenarioplanner.converters import mmm + + +__all__ = ['ModelConverter'] + + +# The output type of a converter. +O = TypeVar('O') + + +class ModelConverter(abc.ABC, Generic[O]): + """Converts a fully specified trained model to secondary form(s) `O`. 
+
+  Attributes:
+    mmm: An `Mmm` wrapper around the bound proto, which contains a trained
+      model and its optional analyses.
+  """
+
+  def __init__(
+      self,
+      mmm_proto: pb.Mmm,
+  ):
+    self._mmm = mmm.Mmm(mmm_proto)
+
+  @property
+  def mmm(self) -> mmm.Mmm:
+    return self._mmm
+
+  @abc.abstractmethod
+  def __call__(self, **kwargs) -> Mapping[str, O]:
+    """Converts the bound `Mmm` proto to named secondary form(s) `O`."""
+    raise NotImplementedError()
diff --git a/scenarioplanner/converters/mmm_test.py b/scenarioplanner/converters/mmm_test.py
new file mode 100644
index 000000000..ab51aee2a
--- /dev/null
+++ b/scenarioplanner/converters/mmm_test.py
@@ -0,0 +1,1020 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+from absl.testing import absltest
+from meridian import constants as c
+from mmm.v1 import mmm_pb2 as mmm_pb
+from mmm.v1.common import date_interval_pb2 as date_interval_pb
+from mmm.v1.common import estimate_pb2 as estimate_pb
+from mmm.v1.common import target_metric_pb2 as target_metric_pb
+from mmm.v1.fit import model_fit_pb2 as fit_pb
+from mmm.v1.marketing.analysis import marketing_analysis_pb2 as marketing_pb
+from mmm.v1.marketing.analysis import media_analysis_pb2 as media_pb
+from mmm.v1.marketing.analysis import non_media_analysis_pb2 as non_media_pb
+from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb
+from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
+from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
+from mmm.v1.marketing.optimization import marketing_optimization_pb2 as optimization_pb
+from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
+from mmm.v1.model import mmm_kernel_pb2 as kernel_pb
+from scenarioplanner.converters import mmm
+from scenarioplanner.converters import test_data as td
+
+from google.type import date_pb2 as date_pb
+
+
+class MmmTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+
+    self._mmm_proto = mmm_pb.Mmm(
+        mmm_kernel=kernel_pb.MmmKernel(
+            marketing_data=td.MARKETING_DATA,
+        ),
+        model_fit=fit_pb.ModelFit(
+            results=[
+                td.MODEL_FIT_RESULT_TRAIN,
+                td.MODEL_FIT_RESULT_TEST,
+                td.MODEL_FIT_RESULT_ALL_DATA,
+            ]
+        ),
+        marketing_optimization=optimization_pb.MarketingOptimization(
+            budget_optimization=budget_pb.BudgetOptimization(
+                results=[
+                    td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES,
+                    td.BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV,
+                ]
+            ),
+            reach_frequency_optimization=rf_pb.ReachFrequencyOptimization(
+                results=[
+                    td.RF_OPTIMIZATION_RESULT_FOO,
+                ]
+            ),
+        ),
+    )
+
+  def test_marketing_data(self):
+    output = mmm.Mmm(self._mmm_proto)
+    self.assertEqual(
+        output.marketing_data.marketing_data_proto,
+        td.MARKETING_DATA,
+    )
+
+  def test_model_fit(self):
+    output = mmm.Mmm(self._mmm_proto)
+    self.assertEqual(output.model_fit, self._mmm_proto.model_fit)
+
+  def test_model_fit_mapped_results(self):
+    output = mmm.Mmm(self._mmm_proto)
+    self.assertEqual(
+        output.model_fit_results,
+        {
c.TRAIN: td.MODEL_FIT_RESULT_TRAIN, + c.TEST: td.MODEL_FIT_RESULT_TEST, + c.ALL_DATA: td.MODEL_FIT_RESULT_ALL_DATA, + }, + ) + + def test_budget_optimization_results(self): + output = mmm.Mmm(self._mmm_proto) + self.assertEqual( + output.budget_optimization_results, + [ + mmm.BudgetOptimizationResult( + td.BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES + ), + mmm.BudgetOptimizationResult( + td.BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV + ), + ], + ) + + def test_reach_frequency_optimization_results(self): + output = mmm.Mmm(self._mmm_proto) + self.assertEqual( + output.reach_frequency_optimization_results, + [ + mmm.ReachFrequencyOptimizationResult(td.RF_OPTIMIZATION_RESULT_FOO), + ], + ) + + +class MarketingDataTest(absltest.TestCase): + + def test_media_channels(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + self.assertEqual( + marketing_data.media_channels, + ["Channel 1", "Channel 2"], + ) + + def test_rf_channels(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + self.assertEqual( + marketing_data.rf_channels, + ["RF Channel 1", "RF Channel 2"], + ) + + def test_date_intervals(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + self.assertEqual( + marketing_data.date_intervals, + [ + mmm.DateInterval( + (datetime.date(2024, 1, 1), datetime.date(2024, 1, 8)) + ), + mmm.DateInterval( + (datetime.date(2024, 1, 8), datetime.date(2024, 1, 15)) + ), + ], + ) + + def test_media_channel_spends(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + date_interval = ("2024-01-08", "2024-01-15") + self.assertEqual( + marketing_data.media_channel_spends(date_interval), + { + "Channel 1": td.BASE_MEDIA_SPEND * 2 * 1, # x geo x time + "Channel 2": td.BASE_MEDIA_SPEND * 2 * 1, + }, + ) + + def test_media_channel_spends_outside_given_interval(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + date_interval = ("2024-01-15", "2024-01-30") + self.assertEqual( + marketing_data.media_channel_spends(date_interval), + { + "Channel 1": 0.0, + "Channel 2": 0.0, + }, + ) + + def test_rf_channel_spends(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + date_interval = ("2024-01-08", "2024-01-15") + self.assertEqual( + marketing_data.rf_channel_spends(date_interval), + { + "RF Channel 1": td.BASE_RF_MEDIA_SPEND * 2 * 1, # x geo x time + "RF Channel 2": td.BASE_RF_MEDIA_SPEND * 2 * 1, + }, + ) + + def test_rf_channel_spends_outside_given_interval(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + date_interval = ("2024-01-15", "2024-01-30") + self.assertEqual( + marketing_data.rf_channel_spends(date_interval), + { + "RF Channel 1": 0.0, + "RF Channel 2": 0.0, + }, + ) + + def test_all_channel_spends(self): + marketing_data = mmm.MarketingData(td.MARKETING_DATA) + date_interval = ("2024-01-08", "2024-01-15") + self.assertEqual( + marketing_data.all_channel_spends(date_interval), + { + "Channel 1": td.BASE_MEDIA_SPEND * 2 * 1, + "Channel 2": td.BASE_MEDIA_SPEND * 2 * 1, + "RF Channel 1": td.BASE_RF_MEDIA_SPEND * 2 * 1, + "RF Channel 2": td.BASE_RF_MEDIA_SPEND * 2 * 1, + }, + ) + + +class MarketingAnalysisTest(absltest.TestCase): + + def test_channel_mapped_media_analyses(self): + analysis_1 = media_pb.MediaAnalysis(channel_name="Channel 1") + analysis_2 = media_pb.MediaAnalysis(channel_name="Channel 2") + analysis_3 = media_pb.MediaAnalysis(channel_name="Channel 3") + + wrapper = mmm.MarketingAnalysis( + marketing_pb.MarketingAnalysis( + media_analyses=[ + analysis_1, + analysis_2, + analysis_3, + ] + ) + ) + + 
self.assertEqual( + { + channel: media_analysis.analysis_proto + for channel, media_analysis in wrapper.channel_mapped_media_analyses.items() + }, + { + "Channel 1": analysis_1, + "Channel 2": analysis_2, + "Channel 3": analysis_3, + }, + ) + + def test_get_baseline_analysis(self): + analysis_1 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 1") + analysis_2 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 2") + analysis_3 = non_media_pb.NonMediaAnalysis(non_media_name=c.BASELINE) + + wrapper = mmm.MarketingAnalysis( + marketing_pb.MarketingAnalysis( + non_media_analyses=[ + analysis_1, + analysis_2, + analysis_3, + ] + ) + ) + + self.assertEqual( + wrapper.baseline_analysis.analysis_proto, + analysis_3, + ) + + def test_get_channel_mapped_non_media_analyses(self): + analysis_1 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 1") + analysis_2 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 2") + analysis_3 = non_media_pb.NonMediaAnalysis(non_media_name=c.BASELINE) + + wrapper = mmm.MarketingAnalysis( + marketing_pb.MarketingAnalysis( + non_media_analyses=[ + analysis_1, + analysis_2, + analysis_3, + ] + ) + ) + + self.assertEqual( + { + channel: non_media_analysis.analysis_proto + for channel, non_media_analysis in wrapper.channel_mapped_non_media_analyses.items() + }, + { + "Channel 1": analysis_1, + "Channel 2": analysis_2, + c.BASELINE: analysis_3, + }, + ) + + def test_get_baseline_analysis_missing(self): + analysis_1 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 1") + analysis_2 = non_media_pb.NonMediaAnalysis(non_media_name="Channel 2") + + wrapper = mmm.MarketingAnalysis( + marketing_pb.MarketingAnalysis( + non_media_analyses=[ + analysis_1, + analysis_2, + ] + ) + ) + + with self.assertRaises(ValueError): + _ = wrapper.baseline_analysis + + def test_response_curves(self): + curve_1 = response_curve_pb.ResponseCurve(input_name="Spend 1") + curve_2 = response_curve_pb.ResponseCurve(input_name="Spend 2") + media_analysis_1 = media_pb.MediaAnalysis( + channel_name="Channel 1", + response_curve=curve_1, + ) + media_analysis_2 = media_pb.MediaAnalysis( + channel_name="Channel 2", + response_curve=curve_2, + ) + marketing_analysis = marketing_pb.MarketingAnalysis( + media_analyses=[ + media_analysis_1, + media_analysis_2, + ] + ) + wrapper = mmm.MarketingAnalysis(marketing_analysis) + self.assertEqual( + [curve.response_curve_proto for curve in wrapper.response_curves], + [curve_1, curve_2], + ) + + +class MediaAnalysisTest(absltest.TestCase): + + def test_channel_name(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis(channel_name="Channel 1") + ) + self.assertEqual(analysis.channel_name, "Channel 1") + + def test_spend_info_pb(self): + spend_info_pb = media_pb.SpendInfo( + spend=1000.0, + spend_share=0.5, + ) + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis(spend_info=spend_info_pb) + ) + self.assertEqual( + analysis.spend_info_pb, + spend_info_pb, + ) + + def test_maybe_revenue_outcome_has_no_revenue_outcome(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.NON_REVENUE_OUTCOME, + ] + ) + ) + self.assertIsNone(analysis.maybe_revenue_outcome) + + def test_maybe_revenue_outcome_has_both_outcomes(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.NON_REVENUE_OUTCOME, + td.REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.maybe_revenue_outcome, + mmm.Outcome(td.REVENUE_OUTCOME), + ) + + def test_revenue_outcome(self): + analysis = 
mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.revenue_outcome, + mmm.Outcome(td.REVENUE_OUTCOME), + ) + + def test_revenue_outcome_missing(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.NON_REVENUE_OUTCOME, + ] + ) + ) + with self.assertRaises(ValueError): + _ = analysis.revenue_outcome + + def test_maybe_non_revenue_outcome_has_no_non_revenue_outcome(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.REVENUE_OUTCOME, + ] + ) + ) + self.assertIsNone(analysis.maybe_non_revenue_outcome) + + def test_maybe_non_revenue_outcome_has_both_outcomes(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.REVENUE_OUTCOME, + td.NON_REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.maybe_non_revenue_outcome, + mmm.Outcome(td.NON_REVENUE_OUTCOME), + ) + + def test_non_revenue_outcome(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.NON_REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.non_revenue_outcome, + mmm.Outcome(td.NON_REVENUE_OUTCOME), + ) + + def test_non_revenue_outcome_missing(self): + analysis = mmm.MediaAnalysis( + media_pb.MediaAnalysis( + media_outcomes=[ + td.REVENUE_OUTCOME, + ] + ) + ) + with self.assertRaises(ValueError): + _ = analysis.non_revenue_outcome + + +class NonMediaAnalysisTest(absltest.TestCase): + + def test_non_media_name(self): + analysis = mmm.NonMediaAnalysis( + non_media_pb.NonMediaAnalysis(non_media_name="Baseline") + ) + self.assertEqual(analysis.non_media_name, "Baseline") + + def test_revenue_outcome(self): + analysis = mmm.NonMediaAnalysis( + non_media_pb.NonMediaAnalysis( + non_media_outcomes=[ + td.REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.revenue_outcome, + mmm.Outcome(td.REVENUE_OUTCOME), + ) + + def test_revenue_outcome_missing(self): + analysis = mmm.NonMediaAnalysis( + non_media_pb.NonMediaAnalysis( + non_media_outcomes=[ + td.NON_REVENUE_OUTCOME, + ] + ) + ) + with self.assertRaisesRegex(ValueError, "No revenue-type `Outcome` found"): + _ = analysis.revenue_outcome + + def test_non_revenue_outcome(self): + analysis = mmm.NonMediaAnalysis( + non_media_pb.NonMediaAnalysis( + non_media_outcomes=[ + td.NON_REVENUE_OUTCOME, + ] + ) + ) + self.assertEqual( + analysis.non_revenue_outcome, + mmm.Outcome(td.NON_REVENUE_OUTCOME), + ) + + def test_non_revenue_outcome_missing(self): + analysis = mmm.NonMediaAnalysis( + non_media_pb.NonMediaAnalysis( + non_media_outcomes=[ + td.REVENUE_OUTCOME, + ] + ) + ) + with self.assertRaisesRegex( + ValueError, "No nonrevenue-type `Outcome` found" + ): + _ = analysis.non_revenue_outcome + + +class ResponseCurveTest(absltest.TestCase): + + def test_input_name(self): + curve = mmm.ResponseCurve( + "Channel 1", + response_curve_pb.ResponseCurve(input_name="Spend"), + ) + self.assertEqual(curve.input_name, "Spend") + + def test_response_points(self): + curve = mmm.ResponseCurve( + "Channel 1", + response_curve_pb.ResponseCurve( + response_points=[ + response_curve_pb.ResponsePoint( + input_value=1000.0, + incremental_kpi=10.0, + ), + response_curve_pb.ResponsePoint( + input_value=2000.0, + incremental_kpi=20.0, + ), + ] + ), + ) + self.assertEqual( + curve.response_points, + [(1000.0, 10.0), (2000.0, 20.0)], + ) + + +class RevenueOutcomeTest(absltest.TestCase): + + def test_is_revenue_kpi(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + 
self.assertTrue(outcome.is_revenue_kpi) + + def test_is_nonrevenue_kpi(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + self.assertFalse(outcome.is_nonrevenue_kpi) + + def test_contribution_pb(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + self.assertEqual( + outcome.contribution_pb, + td.REVENUE_OUTCOME.contribution, + ) + + def test_effectiveness_pb(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + self.assertEqual( + outcome.effectiveness_pb, + td.REVENUE_OUTCOME.effectiveness, + ) + + def test_roi_pb(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + self.assertEqual(outcome.roi_pb, td.REVENUE_OUTCOME.roi) + + def test_marginal_roi_pb(self): + outcome = mmm.Outcome(td.REVENUE_OUTCOME) + self.assertEqual( + outcome.marginal_roi_pb, + td.REVENUE_OUTCOME.marginal_roi, + ) + + +class NonRevenueOutcomeTest(absltest.TestCase): + + def test_is_revenue_kpi(self): + outcome = mmm.Outcome(td.NON_REVENUE_OUTCOME) + self.assertFalse(outcome.is_revenue_kpi) + + def test_is_nonrevenue_kpi(self): + outcome = mmm.Outcome(td.NON_REVENUE_OUTCOME) + self.assertTrue(outcome.is_nonrevenue_kpi) + + def test_contribution_pb(self): + outcome = mmm.Outcome(td.NON_REVENUE_OUTCOME) + self.assertEqual( + outcome.contribution_pb, + td.NON_REVENUE_OUTCOME.contribution, + ) + + def test_effectiveness_pb(self): + outcome = mmm.Outcome(td.NON_REVENUE_OUTCOME) + self.assertEqual( + outcome.effectiveness_pb, + td.NON_REVENUE_OUTCOME.effectiveness, + ) + + def test_cost_per_contribution_pb(self): + outcome = mmm.Outcome(td.NON_REVENUE_OUTCOME) + self.assertEqual( + outcome.cost_per_contribution_pb, + td.NON_REVENUE_OUTCOME.cost_per_contribution, + ) + + +class IncrementalOutcomeGridTest(absltest.TestCase): + + def test_name(self): + grid = mmm.IncrementalOutcomeGrid( + budget_pb.IncrementalOutcomeGrid(name="Test") + ) + self.assertEqual(grid.name, "Test") + + def test_channel_spend_grids(self): + grid = mmm.IncrementalOutcomeGrid( + budget_pb.IncrementalOutcomeGrid( + channel_cells=[ + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name="Channel 1", + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=1000.0, + incremental_outcome=estimate_pb.Estimate( + value=10.0 + ), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=2000.0, + incremental_outcome=estimate_pb.Estimate( + value=20.0 + ), + ), + ], + ), + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name="Channel 2", + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=1000.0, + incremental_outcome=estimate_pb.Estimate( + value=10.0 + ), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=2000.0, + incremental_outcome=estimate_pb.Estimate( + value=20.0 + ), + ), + ], + ), + ] + ) + ) + self.assertEqual( + grid.channel_spend_grids, + { + "Channel 1": [ + (1000.0, 10.0), + (2000.0, 20.0), + ], + "Channel 2": [ + (1000.0, 10.0), + (2000.0, 20.0), + ], + }, + ) + + +class BudgetOptimizationSpecTest(absltest.TestCase): + + def test_date_intervals(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + date_interval=date_interval_pb.DateInterval( + start_date=date_pb.Date(year=2024, month=1, day=1), + end_date=date_pb.Date(year=2024, month=1, day=8), + ), + ) + ) + self.assertEqual( + spec.date_interval.date_interval, + (datetime.date(2024, 1, 1), datetime.date(2024, 1, 8)), + ) + + def test_objective(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + objective=target_metric_pb.TargetMetric.ROI, + ) + ) + self.assertEqual(spec.objective, 
target_metric_pb.TargetMetric.ROI) + + def test_is_fixed_scenario_fixed(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + fixed_budget_scenario=budget_pb.FixedBudgetScenario( + total_budget=1000.0 + ), + ) + ) + self.assertTrue(spec.is_fixed_scenario) + + def test_is_fixed_scenario_flexible(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + flexible_budget_scenario=budget_pb.FlexibleBudgetScenario( + total_budget_constraint=constraints_pb.BudgetConstraint( + min_budget=100.0, + max_budget=1000.0, + ), + ), + ) + ) + self.assertFalse(spec.is_fixed_scenario) + + def test_max_budget_fixed(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + fixed_budget_scenario=budget_pb.FixedBudgetScenario( + total_budget=1000.0 + ), + ) + ) + self.assertEqual(spec.max_budget, 1000.0) + + def test_max_budget_flexible(self): + spec = mmm.BudgetOptimizationSpec( + budget_pb.BudgetOptimizationSpec( + flexible_budget_scenario=budget_pb.FlexibleBudgetScenario( + total_budget_constraint=constraints_pb.BudgetConstraint( + min_budget=100.0, + max_budget=1000.0, + ), + ), + ) + ) + self.assertEqual(spec.max_budget, 1000.0) + + +class RfOptimizationSpecTest(absltest.TestCase): + + def test_date_intervals(self): + spec = mmm.RfOptimizationSpec( + rf_pb.ReachFrequencyOptimizationSpec( + date_interval=date_interval_pb.DateInterval( + start_date=date_pb.Date(year=2024, month=1, day=1), + end_date=date_pb.Date(year=2024, month=1, day=8), + ), + ) + ) + self.assertEqual( + spec.date_interval.date_interval, + (datetime.date(2024, 1, 1), datetime.date(2024, 1, 8)), + ) + + def test_objective(self): + spec = mmm.RfOptimizationSpec( + rf_pb.ReachFrequencyOptimizationSpec( + objective=target_metric_pb.TargetMetric.ROI, + ) + ) + self.assertEqual(spec.objective, target_metric_pb.TargetMetric.ROI) + + def test_total_budget_constraint(self): + budget_constraint_proto = constraints_pb.BudgetConstraint( + min_budget=1000.0, + max_budget=1000.0, + ) + spec = mmm.RfOptimizationSpec( + rf_pb.ReachFrequencyOptimizationSpec( + total_budget_constraint=budget_constraint_proto, + ) + ) + self.assertEqual(spec.total_budget_constraint, budget_constraint_proto) + + def test_channel_constraints(self): + channel_constraints = [ + rf_pb.RfChannelConstraint( + channel_name="Channel 1", + budget_constraint=constraints_pb.BudgetConstraint( + min_budget=1000.0, + max_budget=1000.0, + ), + frequency_constraint=constraints_pb.FrequencyConstraint( + min_frequency=1.0, + max_frequency=10.0, + ), + ), + rf_pb.RfChannelConstraint( + channel_name="Channel 2", + budget_constraint=constraints_pb.BudgetConstraint( + min_budget=1000.0, + max_budget=1000.0, + ), + frequency_constraint=constraints_pb.FrequencyConstraint( + min_frequency=2.0, + max_frequency=12.0, + ), + ), + ] + spec = mmm.RfOptimizationSpec( + rf_pb.ReachFrequencyOptimizationSpec( + rf_channel_constraints=channel_constraints, + ) + ) + self.assertEqual(spec.channel_constraints, channel_constraints) + + +class BudgetOptimizationResultTest(absltest.TestCase): + + def test_group_id(self): + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult(name="Test", group_id="group") + ) + self.assertEqual(result.group_id, "group") + + def test_name(self): + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult(name="Test") + ) + self.assertEqual(result.name, "Test") + + def test_spec(self): + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult( + 
name="Test", + spec=td.BUDGET_OPTIMIZATION_SPEC_FIXED_ALL_DATES, + ) + ) + self.assertEqual( + result.spec.budget_optimization_spec_proto, + td.BUDGET_OPTIMIZATION_SPEC_FIXED_ALL_DATES, + ) + + def test_optimized_marketing_analysis(self): + analysis_proto = marketing_pb.MarketingAnalysis( + date_interval=date_interval_pb.DateInterval( + start_date=date_pb.Date(year=2024, month=1, day=1), + end_date=date_pb.Date(year=2024, month=1, day=8), + tag="Tag", + ), + ) + + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult( + name="Test", + optimized_marketing_analysis=analysis_proto, + ) + ) + self.assertEqual( + result.optimized_marketing_analysis, + mmm.MarketingAnalysis(analysis_proto), + ) + + def test_incremental_outcome_grid(self): + grid_proto = budget_pb.IncrementalOutcomeGrid(name="Test") + + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult( + name="Test", + incremental_outcome_grid=grid_proto, + ) + ) + self.assertEqual( + result.incremental_outcome_grid, + mmm.IncrementalOutcomeGrid(grid_proto), + ) + + def test_response_curves(self): + curve1 = response_curve_pb.ResponseCurve( + input_name="Spend", + response_points=[ + response_curve_pb.ResponsePoint( + input_value=1000.0, + incremental_kpi=10.0, + ), + response_curve_pb.ResponsePoint( + input_value=2000.0, + incremental_kpi=20.0, + ), + ], + ) + curve2 = response_curve_pb.ResponseCurve( + input_name="Spend", + response_points=[ + response_curve_pb.ResponsePoint( + input_value=1010.0, + incremental_kpi=11.0, + ), + response_curve_pb.ResponsePoint( + input_value=2010.0, + incremental_kpi=21.0, + ), + ], + ) + result = mmm.BudgetOptimizationResult( + budget_pb.BudgetOptimizationResult( + name="Test", + optimized_marketing_analysis=marketing_pb.MarketingAnalysis( + media_analyses=[ + media_pb.MediaAnalysis( + channel_name="Channel 1", response_curve=curve1 + ), + media_pb.MediaAnalysis( + channel_name="Channel 2", response_curve=curve2 + ), + ], + ), + ) + ) + self.assertEqual( + result.response_curves, + [ + mmm.ResponseCurve("Channel 1", curve1), + mmm.ResponseCurve("Channel 2", curve2), + ], + ) + + +class FrequencyOutcomeGridTest(absltest.TestCase): + + def test_name(self): + grid = mmm.FrequencyOutcomeGrid(rf_pb.FrequencyOutcomeGrid(name="Test")) + self.assertEqual(grid.name, "Test") + + def test_channel_frequency_grids(self): + grid = mmm.FrequencyOutcomeGrid(td.FREQUENCY_OUTCOME_GRID_FOO) + self.assertEqual( + grid.channel_frequency_grids, + { + "RF Channel 1": [ + (1.0, 100.0), + (2.0, 200.0), + ], + "RF Channel 2": [ + (1.0, 100.0), + (2.0, 200.0), + ], + }, + ) + + +class ReachFrequencyOptimizationResultTest(absltest.TestCase): + + def test_name(self): + result = mmm.ReachFrequencyOptimizationResult( + rf_pb.ReachFrequencyOptimizationResult(name="Test") + ) + self.assertEqual(result.name, "Test") + + def test_group_id(self): + result = mmm.ReachFrequencyOptimizationResult( + rf_pb.ReachFrequencyOptimizationResult(name="Test", group_id="group") + ) + self.assertEqual(result.group_id, "group") + + def test_spec(self): + result = mmm.ReachFrequencyOptimizationResult( + rf_pb.ReachFrequencyOptimizationResult( + name="Test", + spec=td.RF_OPTIMIZATION_SPEC_ALL_DATES, + ) + ) + self.assertEqual( + result.spec.rf_optimization_spec_proto, + td.RF_OPTIMIZATION_SPEC_ALL_DATES, + ) + + def test_channel_mapped_optimized_frequencies(self): + result = mmm.ReachFrequencyOptimizationResult( + rf_pb.ReachFrequencyOptimizationResult( + name="Test", + optimized_channel_frequencies=[ + 
rf_pb.OptimizedChannelFrequency(
+                  channel_name="Channel 1",
+                  optimal_average_frequency=1.0,
+              ),
+              rf_pb.OptimizedChannelFrequency(
+                  channel_name="Channel 2",
+                  optimal_average_frequency=2.0,
+              ),
+          ],
+      )
+    )
+
+    self.assertEqual(
+        result.channel_mapped_optimized_frequencies,
+        {
+            "Channel 1": 1.0,
+            "Channel 2": 2.0,
+        },
+    )
+
+  def test_optimized_marketing_analysis(self):
+    analysis_proto = marketing_pb.MarketingAnalysis(
+        date_interval=date_interval_pb.DateInterval(
+            start_date=date_pb.Date(year=2024, month=1, day=1),
+            end_date=date_pb.Date(year=2024, month=1, day=8),
+            tag="Tag",
+        ),
+    )
+
+    result = mmm.ReachFrequencyOptimizationResult(
+        rf_pb.ReachFrequencyOptimizationResult(
+            name="Test",
+            optimized_marketing_analysis=analysis_proto,
+        )
+    )
+    self.assertEqual(
+        result.optimized_marketing_analysis,
+        mmm.MarketingAnalysis(analysis_proto),
+    )
+
+  def test_frequency_outcome_grid(self):
+    grid_proto = rf_pb.FrequencyOutcomeGrid(name="Test")
+    result = mmm.ReachFrequencyOptimizationResult(
+        rf_pb.ReachFrequencyOptimizationResult(
+            name="Test",
+            frequency_outcome_grid=grid_proto,
+        )
+    )
+    self.assertEqual(
+        result.frequency_outcome_grid,
+        mmm.FrequencyOutcomeGrid(grid_proto),
+    )
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/scenarioplanner/converters/sheets.py b/scenarioplanner/converters/sheets.py
new file mode 100644
index 000000000..7facf058e
--- /dev/null
+++ b/scenarioplanner/converters/sheets.py
@@ -0,0 +1,152 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility library for compiling Google Sheets.
+
+This library requires authentication.
+
+* If you're developing locally, set up Application Default Credentials (ADC) in
+  your local environment:
+
+  
+
+* If you're working in Colab, run the following command in a cell to
+  authenticate:
+
+  ```python
+  from google.colab import auth
+  auth.authenticate_user()
+  ```
+
+  This command opens a window where you can complete the authentication.
+"""
+
+from collections.abc import Mapping
+import dataclasses
+
+import google.auth
+from googleapiclient import discovery
+import pandas as pd
+
+
+__all__ = [
+    "Spreadsheet",
+    "upload_to_gsheet",
+]
+
+
+# https://developers.google.com/sheets/api/scopes#sheets-scopes
+_SCOPES = [
+    "https://www.googleapis.com/auth/spreadsheets",
+    "https://www.googleapis.com/auth/drive",
+]
+_DEFAULT_SHEET_ID = 0
+_ADD_SHEET_REQUEST_NAME = "addSheet"
+_DELETE_SHEET_REQUEST_NAME = "deleteSheet"
+_DEFAULT_SPREADSHEET_NAME = "Meridian Looker Studio Data"
+
+
+@dataclasses.dataclass(frozen=True)
+class Spreadsheet:
+  """Spreadsheet data class.
+
+  Attributes:
+    url: URL of the spreadsheet.
+    id: ID of the spreadsheet.
+    sheet_id_by_tab_name: Mapping of sheet tab names to sheet IDs.
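+
+  Example (illustrative values; a real instance is normally returned by
+  `upload_to_gsheet` rather than constructed by hand):
+
+  ```python
+  spreadsheet = Spreadsheet(
+      url="https://docs.google.com/spreadsheets/d/some_id",
+      id="some_id",
+      sheet_id_by_tab_name={"Tab1": 0},
+  )
+  ```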
+  """
+
+  url: str
+  id: str
+  sheet_id_by_tab_name: Mapping[str, int]
+
+
+def upload_to_gsheet(
+    data: Mapping[str, pd.DataFrame],
+    credentials: google.auth.credentials.Credentials | None = None,
+    spreadsheet_name: str = _DEFAULT_SPREADSHEET_NAME,
+) -> Spreadsheet:
+  """Creates a new spreadsheet populated with the given dataframes.
+
+  If no credentials are given, loads pre-authorized user credentials from the
+  environment.
+
+  Args:
+    data: Mapping of tab name to dataframe.
+    credentials: Optional credentials from the user. If omitted, Application
+      Default Credentials are used.
+    spreadsheet_name: Name of the spreadsheet.
+
+  Returns:
+    A `Spreadsheet` describing the created spreadsheet.
+  """
+  if credentials is None:
+    credentials, _ = google.auth.default(scopes=_SCOPES)
+  service = discovery.build("sheets", "v4", credentials=credentials)
+  spreadsheet = (
+      service.spreadsheets()
+      .create(body={"properties": {"title": spreadsheet_name}})
+      .execute()
+  )
+  spreadsheet_id = spreadsheet["spreadsheetId"]
+
+  # Build requests to add worksheets and fill them in.
+  tab_requests = []
+  values_request_body = {
+      "data": [],
+      "valueInputOption": "USER_ENTERED",
+  }
+  for tab_name, dataframe in data.items():
+    tab_requests.append(
+        {_ADD_SHEET_REQUEST_NAME: {"properties": {"title": tab_name}}}
+    )
+    values_request_body["data"].append({
+        "values": [
+            dataframe.columns.values.tolist()
+        ] + dataframe.values.tolist(),
+        "range": f"{tab_name}!A1",
+    })
+  # Delete the default first tab.
+  tab_requests.append(
+      {_DELETE_SHEET_REQUEST_NAME: {"sheetId": _DEFAULT_SHEET_ID}}
+  )
+
+  created_tab_objects = (
+      service.spreadsheets()
+      .batchUpdate(
+          spreadsheetId=spreadsheet_id, body={"requests": tab_requests}
+      )
+      .execute()
+  )
+
+  sheet_id_by_tab_name = dict()
+  for tab in created_tab_objects["replies"]:
+    if _ADD_SHEET_REQUEST_NAME not in tab:
+      continue
+    add_sheet_response_properties = tab.get(_ADD_SHEET_REQUEST_NAME).get(
+        "properties"
+    )
+    tab_name = add_sheet_response_properties.get("title")
+    sheet_id = add_sheet_response_properties.get("sheetId")
+    sheet_id_by_tab_name[tab_name] = sheet_id
+
+  # Fill in the data.
+  service.spreadsheets().values().batchUpdate(
+      spreadsheetId=spreadsheet_id, body=values_request_body
+  ).execute()
+
+  return Spreadsheet(
+      url=spreadsheet.get("spreadsheetUrl"),
+      id=spreadsheet_id,
+      sheet_id_by_tab_name=sheet_id_by_tab_name,
+  )
diff --git a/scenarioplanner/converters/sheets_test.py b/scenarioplanner/converters/sheets_test.py
new file mode 100644
index 000000000..b2f0d06b6
--- /dev/null
+++ b/scenarioplanner/converters/sheets_test.py
@@ -0,0 +1,168 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+from absl.testing import absltest
+from google.auth import credentials as auth_credentials
+import google.auth.transport.requests
+from googleapiclient import discovery
+from scenarioplanner.converters import sheets
+import pandas as pd
+
+
+class SheetsTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # Mock auth.
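+    # `google.auth.default` is patched so the tests never perform a real
+    # Application Default Credentials lookup; the autospec'd credentials
+    # below are never used for actual API calls.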
+ self.mock_auth_default = self.enter_context( + mock.patch.object( + google.auth, + 'default', + autospec=True, + ) + ) + mock_creds = mock.create_autospec( + auth_credentials.Credentials, instance=True + ) + self.mock_auth_default.return_value = mock_creds, None + self.api_mock = mock.Mock() + + # Mock discovery build. + self.build_mock = self.enter_context( + mock.patch.object( + discovery, + 'build', + return_value=self.api_mock, + autospec=True, + ) + ) + + # Mock spreadsheet create. + self.api_create_mock = self.api_mock.spreadsheets().create + self.api_create_mock.return_value.execute.return_value = { + 'spreadsheetId': 'test_id', + 'title': 'Untitled spreadsheet', + 'spreadsheetUrl': 'test_url', + } + + # Mock sheet values batch update. + self.api_batch_update_mock = self.api_mock.spreadsheets().batchUpdate + self.api_batch_update_mock.return_value.execute.return_value = { + 'replies': [ + { + 'addSheet': { + 'properties': { + 'sheetId': 1, + 'title': 'Tab1', + } + } + }, + { + 'addSheet': { + 'properties': { + 'sheetId': 2, + 'title': 'Tab2', + } + } + }, + {'deleteSheet': {'sheetId': 0}}, + ] + } + self.api_values_batch_update_mock = ( + self.api_mock.spreadsheets().values().batchUpdate + ) + + def test_upload_to_gsheet_creates_spreadsheet_with_name(self): + spreadsheet_name = 'test_title' + sheets.upload_to_gsheet({}, spreadsheet_name=spreadsheet_name) + self.api_create_mock.assert_called_once_with( + body={'properties': {'title': spreadsheet_name}}, + ) + + def test_upload_to_gsheet_output_is_correct(self): + spreadsheet = sheets.upload_to_gsheet({}) + self.assertEqual( + spreadsheet, + sheets.Spreadsheet( + id='test_id', + url='test_url', + sheet_id_by_tab_name={ + 'Tab1': 1, + 'Tab2': 2, + }, + ), + ) + + def test_upload_to_gsheet_has_called_batch_update_with_correct_requests(self): + dict_of_dataframes = { + 'Tab1': pd.DataFrame({ + 'NumberColumn': [1, 2, 3], + 'StringColumn': ['abc', '', 'def'], + 'NoneColumn': [None, None, None], + }), + 'Tab2': pd.DataFrame({ + 'FloatColumn': [7.1, 8.2], + 'DateColumn': ['2021-02-01,2021-02-22', '2021-02-01,2021-02-22'], + }), + } + sheets.upload_to_gsheet(dict_of_dataframes) + + self.api_batch_update_mock.assert_has_calls([ + mock.call( + spreadsheetId='test_id', + body={ + 'requests': [ + {'addSheet': {'properties': {'title': 'Tab1'}}}, + {'addSheet': {'properties': {'title': 'Tab2'}}}, + {'deleteSheet': {'sheetId': 0}}, + ] + }, + ), + mock.call().execute(), + ]) + values_request_body = { + 'data': [ + { + 'values': [ + ['NumberColumn', 'StringColumn', 'NoneColumn'], + [1, 'abc', None], + [2, '', None], + [3, 'def', None], + ], + 'range': 'Tab1!A1', + }, + { + 'values': [ + ['FloatColumn', 'DateColumn'], + [7.1, '2021-02-01,2021-02-22'], + [8.2, '2021-02-01,2021-02-22'], + ], + 'range': 'Tab2!A1', + }, + ], + 'valueInputOption': 'USER_ENTERED', + } + self.api_values_batch_update_mock.assert_has_calls([ + mock.call( + spreadsheetId='test_id', + body=values_request_body, + ), + mock.call().execute(), + ]) + + +if __name__ == '__main__': + absltest.main() diff --git a/scenarioplanner/converters/test_data.py b/scenarioplanner/converters/test_data.py new file mode 100644 index 000000000..62bbcf3db --- /dev/null +++ b/scenarioplanner/converters/test_data.py @@ -0,0 +1,714 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared test data."""
+
+from collections.abc import Iterator, Sequence
+
+from meridian import constants as c
+from mmm.v1.common import date_interval_pb2 as date_interval_pb
+from mmm.v1.common import estimate_pb2 as estimate_pb
+from mmm.v1.common import kpi_type_pb2 as kpi_type_pb
+from mmm.v1.common import target_metric_pb2 as target_metric_pb
+from mmm.v1.fit import model_fit_pb2 as fit_pb
+from mmm.v1.marketing import marketing_data_pb2 as marketing_data_pb
+from mmm.v1.marketing.analysis import marketing_analysis_pb2 as marketing_pb
+from mmm.v1.marketing.analysis import media_analysis_pb2 as media_pb
+from mmm.v1.marketing.analysis import non_media_analysis_pb2 as non_media_pb
+from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb
+from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb
+from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
+from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
+from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
+from scenarioplanner.converters.dataframe import constants as cc
+
+from google.type import date_pb2 as date_pb
+
+
+DATES = [
+    date_pb.Date(year=2024, month=1, day=1),
+    date_pb.Date(year=2024, month=1, day=8),
+    date_pb.Date(year=2024, month=1, day=15),
+]
+DATE_INTERVALS = [
+    date_interval_pb.DateInterval(
+        start_date=DATES[0],
+        end_date=DATES[1],
+        tag="Week1",
+    ),
+    date_interval_pb.DateInterval(
+        start_date=DATES[1],
+        end_date=DATES[2],
+        tag="Week2",
+    ),
+]
+ALL_DATE_INTERVAL = date_interval_pb.DateInterval(
+    start_date=DATES[0],
+    end_date=DATES[2],
+    tag=cc.ANALYSIS_TAG_ALL,
+)
+
+
+GEO_INFOS = [
+    marketing_data_pb.GeoInfo(
+        geo_id="geo-1",
+        population=100,
+    ),
+    marketing_data_pb.GeoInfo(
+        geo_id="geo-2",
+        population=200,
+    ),
+]
+
+
+MEDIA_CHANNELS = [
+    "Channel 1",
+    "Channel 2",
+]
+RF_CHANNELS = [
+    "RF Channel 1",
+    "RF Channel 2",
+]
+
+
+BASE_MEDIA_SPEND = 100.0
+BASE_RF_MEDIA_SPEND = 110.0
+
+
+def _create_marketing_data(
+    create_rf_data: bool = True,
+) -> Iterator[marketing_data_pb.MarketingDataPoint]:
+  """Yields default `MarketingDataPoint`s for each geo and date interval defined above."""
+  for geo_info in GEO_INFOS:
+    for date_interval in DATE_INTERVALS:
+      media_vars = []
+      rf_vars = []
+      for channel in MEDIA_CHANNELS:
+        media_var = marketing_data_pb.MediaVariable(
+            channel_name=channel,
+            # For simplicity, set all media spend to be the same across all
+            # channels and across all geo and time dimensions.
+            # Add function parameters if a more sophisticated test data
+            # generator is warranted here.
+ media_spend=BASE_MEDIA_SPEND, + ) + media_vars.append(media_var) + if create_rf_data: + for channel in RF_CHANNELS: + rf_media_var = marketing_data_pb.ReachFrequencyVariable( + channel_name=channel, + spend=BASE_RF_MEDIA_SPEND, + reach=10_000, + average_frequency=1.1, + ) + rf_vars.append(rf_media_var) + yield marketing_data_pb.MarketingDataPoint( + date_interval=date_interval, + geo_info=geo_info, + media_variables=media_vars, + reach_frequency_variables=rf_vars, + # `kpi` and `control_variables` fields are not set, since no test + # needs it just yet. Fill them in when needed. + ) + + +MARKETING_DATA = marketing_data_pb.MarketingData( + marketing_data_points=list(_create_marketing_data()), +) + + +PERFORMANCE_TEST = fit_pb.Performance( + r_squared=0.99, + mape=67.7, + weighted_mape=59.8, + rmse=55.05, +) +PERFORMANCE_TRAIN = fit_pb.Performance( + r_squared=0.91, + mape=60.6, + weighted_mape=55.5, + rmse=59.87, +) +PERFORMANCE_ALL_DATA = fit_pb.Performance( + r_squared=0.94, + mape=60.0, + weighted_mape=55.4, + rmse=52.0, +) + + +def _create_model_fit_result( + name: str, + performance: fit_pb.Performance, +) -> fit_pb.Result: + return fit_pb.Result( + name=name, + performance=performance, + predictions=[ + fit_pb.Prediction( + date_interval=DATE_INTERVALS[0], + predicted_outcome=estimate_pb.Estimate( + value=100.0, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=90.0, + upperbound=110.0, + ) + ], + ), + predicted_baseline=estimate_pb.Estimate( + value=90.0, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=89.0, + upperbound=111.0, + ) + ], + ), + actual_value=105.0, + ), + fit_pb.Prediction( + date_interval=DATE_INTERVALS[1], + predicted_outcome=estimate_pb.Estimate( + value=110.0, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=100.0, + upperbound=120.0, + ) + ], + ), + predicted_baseline=estimate_pb.Estimate( + value=109.0, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=90.0, + upperbound=125.0, + ) + ], + ), + actual_value=115.0, + ), + ], + ) + + +MODEL_FIT_RESULT_TEST = _create_model_fit_result( + name=c.TEST, + performance=PERFORMANCE_TEST, +) +MODEL_FIT_RESULT_TRAIN = _create_model_fit_result( + name=c.TRAIN, + performance=PERFORMANCE_TRAIN, +) +MODEL_FIT_RESULT_ALL_DATA = _create_model_fit_result( + name=c.ALL_DATA, + performance=PERFORMANCE_ALL_DATA, +) + + +def create_outcome( + incremental_outcome: float, + pct_of_contribution: float, + effectiveness: float, + roi: float, + mroi: float, + cpik: float, + is_revenue_type: bool, +) -> outcome_pb.Outcome: + return outcome_pb.Outcome( + kpi_type=( + kpi_type_pb.REVENUE if is_revenue_type else kpi_type_pb.NON_REVENUE + ), + contribution=outcome_pb.Contribution( + value=estimate_pb.Estimate(value=incremental_outcome), + share=estimate_pb.Estimate(value=pct_of_contribution), + ), + effectiveness=outcome_pb.Effectiveness( + media_unit=c.IMPRESSIONS, + value=estimate_pb.Estimate(value=effectiveness), + ), + roi=estimate_pb.Estimate( + value=roi, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=roi * 0.9, + upperbound=roi * 1.1, + ) + ], + ), + marginal_roi=estimate_pb.Estimate(value=mroi), + cost_per_contribution=estimate_pb.Estimate( + value=cpik, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=0.9, + lowerbound=cpik * 0.9, + upperbound=cpik * 1.1, + ) + ], + ), + ) + + +REVENUE_OUTCOME = create_outcome( + 
incremental_outcome=100.0,
+    pct_of_contribution=0.1,
+    effectiveness=3.3,
+    roi=1.0,
+    mroi=10.0,
+    cpik=5.0,
+    is_revenue_type=True,
+)
+
+NON_REVENUE_OUTCOME = create_outcome(
+    incremental_outcome=100.0,
+    pct_of_contribution=0.1,
+    effectiveness=4.4,
+    roi=10.0,
+    mroi=100.0,
+    cpik=100.0,
+    is_revenue_type=False,
+)
+
+
+SPENDS = {
+    MEDIA_CHANNELS[0]: 75_000,
+    MEDIA_CHANNELS[1]: 25_000,
+    RF_CHANNELS[0]: 30_000,
+    RF_CHANNELS[1]: 20_000,
+}
+TOTAL_SPEND = sum(SPENDS.values())
+SPENDS[c.ALL_CHANNELS] = TOTAL_SPEND
+
+
+def create_media_analysis(
+    channel: str,
+    multiplier: float = 1.0,
+    make_revenue_outcome: bool = True,
+    make_non_revenue_outcome: bool = True,
+) -> media_pb.MediaAnalysis:
+  """Creates a `MediaAnalysis` test proto."""
+  # `multiplier` is used to create unique metric numbers for the given channel
+  # from the base template metrics above.
+  outcomes = []
+  if make_revenue_outcome:
+    outcomes.append(
+        create_outcome(
+            incremental_outcome=100.0 * multiplier,
+            pct_of_contribution=0.1 * multiplier,
+            effectiveness=2.2 * multiplier,
+            roi=1.0 * multiplier,
+            mroi=10.0 * multiplier,
+            cpik=5.0 * multiplier,
+            is_revenue_type=True,
+        )
+    )
+  if make_non_revenue_outcome:
+    outcomes.append(
+        create_outcome(
+            incremental_outcome=100.0 * multiplier,
+            pct_of_contribution=0.1 * multiplier,
+            effectiveness=5.5 * multiplier,
+            roi=10.0 * multiplier,
+            mroi=100.0 * multiplier,
+            cpik=100.0 * multiplier,
+            is_revenue_type=False,
+        )
+    )
+
+  response_curve = response_curve_pb.ResponseCurve(
+      input_name="Spend",
+      response_points=[
+          response_curve_pb.ResponsePoint(
+              input_value=1 * multiplier,
+              incremental_kpi=100.0 * multiplier,
+          ),
+          response_curve_pb.ResponsePoint(
+              input_value=2 * multiplier,
+              incremental_kpi=200.0 * multiplier,
+          ),
+      ],
+  )
+  return media_pb.MediaAnalysis(
+      channel_name=channel,
+      spend_info=media_pb.SpendInfo(
+          spend=SPENDS[channel],
+          spend_share=SPENDS[channel] / TOTAL_SPEND,
+      ),
+      media_outcomes=outcomes,
+      response_curve=response_curve,
+  )
+
+
+MEDIA_ANALYSES_BOTH_OUTCOMES = [
+    create_media_analysis(channel, (idx + 1))
+    for (idx, channel) in enumerate(MEDIA_CHANNELS)
+]
+RF_ANALYSES_BOTH_OUTCOMES = [
+    create_media_analysis(channel, (idx + 1))
+    for (idx, channel) in enumerate(RF_CHANNELS)
+]
+MEDIA_ANALYSES_NONREVENUE = [
+    create_media_analysis(
+        channel,
+        # Use a different multiplier value to distinguish from the above.
+        (idx + 1.2),
+        make_revenue_outcome=False,
+    )
+    for (idx, channel) in enumerate(MEDIA_CHANNELS)
+]
+RF_ANALYSES_NONREVENUE = [
+    create_media_analysis(
+        channel,
+        # Use a different multiplier value to distinguish from the above.
+        (idx + 1.2),
+        make_revenue_outcome=False,
+    )
+    for (idx, channel) in enumerate(RF_CHANNELS)
+]
+
+ALL_CHANNELS_ANALYSIS_BOTH_OUTCOMES = create_media_analysis(
+    c.ALL_CHANNELS, multiplier=10
+)
+ALL_CHANNELS_ANALYSIS_NONREVENUE = create_media_analysis(
+    c.ALL_CHANNELS, multiplier=12, make_revenue_outcome=False
+)
+
+BASELINE_NONREVENUE_OUTCOME = create_outcome(
+    incremental_outcome=40.0,
+    pct_of_contribution=0.04,
+    effectiveness=4.4,
+    cpik=75.0,
+    roi=7.0,
+    mroi=70.0,
+    is_revenue_type=False,
+)
+BASELINE_ANALYSIS_NONREVENUE = non_media_pb.NonMediaAnalysis(
+    non_media_name=c.BASELINE,
+    non_media_outcomes=[BASELINE_NONREVENUE_OUTCOME],
+)
+
+BASELINE_REVENUE_OUTCOME = create_outcome(
+    incremental_outcome=50.0,
+    pct_of_contribution=0.05,
+    effectiveness=5.5,
+    roi=1.0,
+    mroi=10.0,
+    cpik=0.5,
+    is_revenue_type=True,
+)
+BASELINE_ANALYSIS_REVENUE = 
non_media_pb.NonMediaAnalysis( + non_media_name=c.BASELINE, + non_media_outcomes=[BASELINE_REVENUE_OUTCOME], +) + +BASELINE_ANALYSIS_BOTH_OUTCOMES = non_media_pb.NonMediaAnalysis( + non_media_name=c.BASELINE, + non_media_outcomes=[BASELINE_NONREVENUE_OUTCOME, BASELINE_REVENUE_OUTCOME], +) + + +def create_marketing_analysis( + date_interval: date_interval_pb.DateInterval, + baseline_analysis: non_media_pb.NonMediaAnalysis = BASELINE_ANALYSIS_BOTH_OUTCOMES, + explicit_channel_analyses: Sequence[media_pb.MediaAnalysis] | None = None, + explicit_all_channels_analysis: media_pb.MediaAnalysis | None = None, +) -> marketing_pb.MarketingAnalysis: + """Create a `MarketingAnalysis` for the given analysis period and tag.""" + media_analyses = ( + list(explicit_channel_analyses) + if explicit_channel_analyses + else (MEDIA_ANALYSES_BOTH_OUTCOMES + RF_ANALYSES_BOTH_OUTCOMES) + ) + media_analyses.append( + explicit_all_channels_analysis + if explicit_all_channels_analysis + else ALL_CHANNELS_ANALYSIS_BOTH_OUTCOMES + ) + + return marketing_pb.MarketingAnalysis( + date_interval=date_interval, + non_media_analyses=[baseline_analysis], + media_analyses=media_analyses, + ) + + +# All of the below test analyses data contain both media and R&F channels. + +ALL_TAG_MARKETING_ANALYSIS_BOTH_OUTCOMES = create_marketing_analysis( + date_interval=ALL_DATE_INTERVAL, + baseline_analysis=BASELINE_ANALYSIS_BOTH_OUTCOMES, +) +ALL_TAG_MARKETING_ANALYSIS_NONREVENUE = create_marketing_analysis( + date_interval=ALL_DATE_INTERVAL, + baseline_analysis=BASELINE_ANALYSIS_NONREVENUE, + explicit_channel_analyses=( + MEDIA_ANALYSES_NONREVENUE + RF_ANALYSES_NONREVENUE + ), +) + +DATED_MARKETING_ANALYSES_BOTH_OUTCOMES = [ + create_marketing_analysis( + date_interval=date_interval, + baseline_analysis=BASELINE_ANALYSIS_BOTH_OUTCOMES, + ) + for date_interval in DATE_INTERVALS +] +DATED_MARKETING_ANALYSES_NONREVENUE = [ + create_marketing_analysis( + date_interval=date_interval, + baseline_analysis=BASELINE_ANALYSIS_NONREVENUE, + explicit_channel_analyses=( + MEDIA_ANALYSES_NONREVENUE + RF_ANALYSES_NONREVENUE + ), + ) + for date_interval in DATE_INTERVALS +] + +MARKETING_ANALYSIS_LIST_BOTH_OUTCOMES = marketing_pb.MarketingAnalysisList( + marketing_analyses=( + [ALL_TAG_MARKETING_ANALYSIS_BOTH_OUTCOMES] + + DATED_MARKETING_ANALYSES_BOTH_OUTCOMES + ), +) + +MARKETING_ANALYSIS_LIST_NONREVENUE = marketing_pb.MarketingAnalysisList( + marketing_analyses=( + [ALL_TAG_MARKETING_ANALYSIS_NONREVENUE] + + DATED_MARKETING_ANALYSES_NONREVENUE + ), +) + + +# Incremental outcome grids (budget) are only relevant for non-RF media +# channels. 
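+# Each grid maps a channel to (spend, incremental outcome) cells; the
+# `IncrementalOutcomeGrid` wrapper under test flattens these into per-channel
+# spend grids such as {"Channel 1": [(1000.0, 10.0), (2000.0, 20.0)]}.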
+ +INCREMENTAL_OUTCOME_GRID_FOO = budget_pb.IncrementalOutcomeGrid( + name="incremental outcome grid foo", + channel_cells=[ + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name=MEDIA_CHANNELS[0], + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=10000.0, + incremental_outcome=estimate_pb.Estimate(value=100.0), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=20000.0, + incremental_outcome=estimate_pb.Estimate(value=200.0), + ), + ], + ), + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name=MEDIA_CHANNELS[1], + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=10000.0, + incremental_outcome=estimate_pb.Estimate(value=100.0), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=20000.0, + incremental_outcome=estimate_pb.Estimate(value=200.0), + ), + ], + ), + ], +) + +INCREMENTAL_OUTCOME_GRID_BAR = budget_pb.IncrementalOutcomeGrid( + name="incremental outcome grid bar", + channel_cells=[ + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name=MEDIA_CHANNELS[0], + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=1000.0, + incremental_outcome=estimate_pb.Estimate(value=10.0), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=2000.0, + incremental_outcome=estimate_pb.Estimate(value=20.0), + ), + ], + ), + budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name=MEDIA_CHANNELS[1], + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=1000.0, + incremental_outcome=estimate_pb.Estimate(value=10.0), + ), + budget_pb.IncrementalOutcomeGrid.Cell( + spend=2000.0, + incremental_outcome=estimate_pb.Estimate(value=20.0), + ), + ], + ), + ], +) + +# A fixed budget scenario for the entire time interval in the test data above. +BUDGET_OPTIMIZATION_SPEC_FIXED_ALL_DATES = budget_pb.BudgetOptimizationSpec( + date_interval=ALL_DATE_INTERVAL, + objective=target_metric_pb.TargetMetric.ROI, + fixed_budget_scenario=budget_pb.FixedBudgetScenario(total_budget=100000.0), + # No individual channel constraints. Expect implicit constraints: max budget + # applied for each channel. +) +BUDGET_OPTIMIZATION_RESULT_FIXED_BOTH_OUTCOMES = ( + budget_pb.BudgetOptimizationResult( + name="budget optimization result foo", + group_id="group-foo", + optimized_marketing_analysis=ALL_TAG_MARKETING_ANALYSIS_BOTH_OUTCOMES, + spec=BUDGET_OPTIMIZATION_SPEC_FIXED_ALL_DATES, + incremental_outcome_grid=INCREMENTAL_OUTCOME_GRID_FOO, + ) +) + +# A flexible budget scenario for the second time interval only. +BUDGET_OPTIMIZATION_SPEC_FLEX_SELECT_DATES = budget_pb.BudgetOptimizationSpec( + date_interval=DATE_INTERVALS[1], + objective=target_metric_pb.TargetMetric.KPI, + flexible_budget_scenario=budget_pb.FlexibleBudgetScenario( + total_budget_constraint=constraints_pb.BudgetConstraint( + min_budget=1000.0, + max_budget=2000.0, + ), + target_metric_constraints=[ + constraints_pb.TargetMetricConstraint( + target_metric=target_metric_pb.COST_PER_INCREMENTAL_KPI, + target_value=10.0, + ) + ], + ), + # Define explicit channel constraints. 
+ channel_constraints=[ + budget_pb.ChannelConstraint( + channel_name=MEDIA_CHANNELS[0], + budget_constraint=constraints_pb.BudgetConstraint( + min_budget=1100.0, + max_budget=1500.0, + ), + ), + budget_pb.ChannelConstraint( + channel_name=MEDIA_CHANNELS[1], + budget_constraint=constraints_pb.BudgetConstraint( + min_budget=1000.0, + max_budget=1800.0, + ), + ), + ], +) +BUDGET_OPTIMIZATION_RESULT_FLEX_NONREV = budget_pb.BudgetOptimizationResult( + name="budget optimization result bar", + group_id="group-bar", + optimized_marketing_analysis=ALL_TAG_MARKETING_ANALYSIS_NONREVENUE, + spec=BUDGET_OPTIMIZATION_SPEC_FLEX_SELECT_DATES, + incremental_outcome_grid=INCREMENTAL_OUTCOME_GRID_BAR, +) + + +# Frequency outcome grids are only relevant for R&F media channels. + +FREQUENCY_OUTCOME_GRID_FOO = rf_pb.FrequencyOutcomeGrid( + name="frequency outcome grid foo", + channel_cells=[ + rf_pb.FrequencyOutcomeGrid.ChannelCells( + channel_name=RF_CHANNELS[0], + cells=[ + rf_pb.FrequencyOutcomeGrid.Cell( + reach_frequency=marketing_data_pb.ReachFrequency( + reach=10000, + average_frequency=1.0, + ), + outcome=estimate_pb.Estimate(value=100.0), + ), + rf_pb.FrequencyOutcomeGrid.Cell( + reach_frequency=marketing_data_pb.ReachFrequency( + reach=20000, + average_frequency=2.0, + ), + outcome=estimate_pb.Estimate(value=200.0), + ), + ], + ), + rf_pb.FrequencyOutcomeGrid.ChannelCells( + channel_name=RF_CHANNELS[1], + cells=[ + rf_pb.FrequencyOutcomeGrid.Cell( + reach_frequency=marketing_data_pb.ReachFrequency( + reach=10000, + average_frequency=1.0, + ), + outcome=estimate_pb.Estimate(value=100.0), + ), + rf_pb.FrequencyOutcomeGrid.Cell( + reach_frequency=marketing_data_pb.ReachFrequency( + reach=20000, + average_frequency=2.0, + ), + outcome=estimate_pb.Estimate(value=200.0), + ), + ], + ), + ], +) + +RF_OPTIMIZATION_SPEC_ALL_DATES = rf_pb.ReachFrequencyOptimizationSpec( + date_interval=ALL_DATE_INTERVAL, + objective=target_metric_pb.TargetMetric.KPI, + total_budget_constraint=constraints_pb.BudgetConstraint( + min_budget=100000.0, + max_budget=200000.0, + ), + rf_channel_constraints=[ + rf_pb.RfChannelConstraint( + channel_name=RF_CHANNELS[0], + frequency_constraint=constraints_pb.FrequencyConstraint( + max_frequency=5.0, + ), + ), + rf_pb.RfChannelConstraint( + channel_name=RF_CHANNELS[1], + frequency_constraint=constraints_pb.FrequencyConstraint( + min_frequency=1.3, + max_frequency=6.6, + ), + ), + ], +) + +RF_OPTIMIZATION_RESULT_FOO = rf_pb.ReachFrequencyOptimizationResult( + name="reach frequency optimization result foo", + group_id="group-foo", + spec=RF_OPTIMIZATION_SPEC_ALL_DATES, + optimized_channel_frequencies=[ + rf_pb.OptimizedChannelFrequency( + channel_name=RF_CHANNELS[0], + optimal_average_frequency=3.3, + ), + rf_pb.OptimizedChannelFrequency( + channel_name=RF_CHANNELS[1], + optimal_average_frequency=5.6, + ), + ], + optimized_marketing_analysis=ALL_TAG_MARKETING_ANALYSIS_BOTH_OUTCOMES, + frequency_outcome_grid=FREQUENCY_OUTCOME_GRID_FOO, +) diff --git a/scenarioplanner/linkingapi/__init__.py b/scenarioplanner/linkingapi/__init__.py new file mode 100644 index 000000000..3989904f5 --- /dev/null +++ b/scenarioplanner/linkingapi/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Builds Looker Studio report URLs.
+
+This package provides tools to construct URLs for Looker Studio reports that
+embed data-source configuration directly within the URL itself. This enables
+shareable, pre-configured reports without requiring a separate, pre-existing
+data source.
+
+The primary functionality is exposed through the `url_generator` module.
+
+Typical Usage:
+
+  1. Use `url_generator.create_report_url()` to create the complete URL based
+     on a `sheets.Spreadsheet` object.
+
+Example:
+
+```python
+from scenarioplanner.converters import sheets
+from scenarioplanner.linkingapi import url_generator
+
+# `spreadsheet` is normally the `sheets.Spreadsheet` returned by
+# `sheets.upload_to_gsheet()`; the literal values here are illustrative.
+spreadsheet = sheets.Spreadsheet(
+    url="some_url",
+    id="some_id",
+    sheet_id_by_tab_name={},
+)
+
+# Generate the URL.
+looker_studio_report_url = url_generator.create_report_url(spreadsheet)
+# The `looker_studio_report_url` can now be shared to open a pre-populated
+# report.
+```
+"""
+
+from scenarioplanner.linkingapi import constants
+from scenarioplanner.linkingapi import url_generator
diff --git a/scenarioplanner/linkingapi/constants.py b/scenarioplanner/linkingapi/constants.py
new file mode 100644
index 000000000..809046f61
--- /dev/null
+++ b/scenarioplanner/linkingapi/constants.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants for Linking API usage in the Meridian UI.
+
+Defines constants used in the URL generation process, such as API endpoints and
+parameter names.
+"""
+
+REPORT_TEMPLATE_ID = 'fbd3aeff-fc00-45fd-83f7-1ec5f21c9f56'
+COMMUNITY_CONNECTOR_NAME = 'community'
+COMMUNITY_CONNECTOR_ID = (
+    'AKfycbz-xdEN-GbTuQ9MjEddS-64wLgXwMMTp9a4zFE4PO_kwT6wDgZPsN4Y19oKmLLHD6xk'
+)
+SHEETS_CONNECTOR_NAME = 'googleSheets'
+GA4_MEASUREMENT_ID = 'G-R6C81BNHJ4'
diff --git a/scenarioplanner/linkingapi/url_generator.py b/scenarioplanner/linkingapi/url_generator.py
new file mode 100644
index 000000000..e418f1086
--- /dev/null
+++ b/scenarioplanner/linkingapi/url_generator.py
@@ -0,0 +1,136 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility library for generating a Looker Studio report and outputting the URL.
+
+Contains the core logic for constructing the Looker Studio report URL,
+including setting the correct parameters.
+
+This library requires authentication.
+
+* If you're developing locally, set up Application Default Credentials (ADC) in
+  your local environment:
+
+  
+
+* If you're working in Colab, run the following command in a cell to
+  authenticate:
+
+  ```python
+  from google.colab import auth
+  auth.authenticate_user()
+  ```
+
+  This command opens a window where you can complete the authentication.
+"""
+
+import urllib.parse
+import warnings
+from scenarioplanner.converters import sheets
+from scenarioplanner.converters.dataframe import constants as dc
+from scenarioplanner.linkingapi import constants
+
+
+def create_report_url(
+    spreadsheet: sheets.Spreadsheet, data_sharing_opt_in: bool = False
+) -> str:
+  """Creates a Looker Studio report URL based on the given spreadsheet.
+
+  If any expected sheet tabs are missing from `spreadsheet`, the report falls
+  back to its demo data for the corresponding sections.
+
+  Args:
+    spreadsheet: The spreadsheet object that contains the data to visualize in
+      a Looker Studio report.
+    data_sharing_opt_in: Whether the user has opted in to share data.
+
+  Returns:
+    The URL of the Looker Studio report.
+  """
+  params = []
+
+  encoded_sheet_url = urllib.parse.quote_plus(spreadsheet.url)
+  data_sharing_opt_in_str = str(data_sharing_opt_in).lower()
+
+  params.append(f'c.reportId={constants.REPORT_TEMPLATE_ID}')
+  params.append(f'r.measurementId={constants.GA4_MEASUREMENT_ID}')
+
+  if dc.OPTIMIZATION_SPECS in spreadsheet.sheet_id_by_tab_name:
+    params.append(f'ds.dscc.connector={constants.COMMUNITY_CONNECTOR_NAME}')
+    params.append(f'ds.dscc.connectorId={constants.COMMUNITY_CONNECTOR_ID}')
+    params.append(f'ds.dscc.spreadsheetUrl={encoded_sheet_url}')
+    params.append(f'ds.dscc.dataSharingOptIn={data_sharing_opt_in_str}')
+  else:
+    warnings.warn(
+        'No optimization specs found in the spreadsheet. The report will'
+        ' display its demo data.'
+    )
+
+  params.append('ds.*.refreshFields=false')
+  params.append('ds.*.keepDatasourceName=true')
+  params.append(f'ds.*.connector={constants.SHEETS_CONNECTOR_NAME}')
+  params.append(f'ds.*.spreadsheetId={spreadsheet.id}')
+
+  if dc.MODEL_FIT in spreadsheet.sheet_id_by_tab_name:
+    worksheet_id = spreadsheet.sheet_id_by_tab_name[dc.MODEL_FIT]
+    params.append(f'ds.ds_model_fit.worksheetId={worksheet_id}')
+  else:
+    warnings.warn(
+        'No model fit found in the spreadsheet. The report will'
+        ' display its demo data.'
+    )
+
+  if dc.MODEL_DIAGNOSTICS in spreadsheet.sheet_id_by_tab_name:
+    worksheet_id = spreadsheet.sheet_id_by_tab_name[dc.MODEL_DIAGNOSTICS]
+    params.append(f'ds.ds_model_diag.worksheetId={worksheet_id}')
+  else:
+    warnings.warn(
+        'No model diagnostics found in the spreadsheet. The report will'
+        ' display its demo data.'
+    )
+
+  if dc.MEDIA_OUTCOME in spreadsheet.sheet_id_by_tab_name:
+    worksheet_id = spreadsheet.sheet_id_by_tab_name[dc.MEDIA_OUTCOME]
+    params.append(f'ds.ds_outcome.worksheetId={worksheet_id}')
+  else:
+    warnings.warn(
+        'No media outcome found in the spreadsheet. The report will'
+        ' display its demo data.'
+    )
+
+  if dc.MEDIA_SPEND in spreadsheet.sheet_id_by_tab_name:
+    worksheet_id = spreadsheet.sheet_id_by_tab_name[dc.MEDIA_SPEND]
+    params.append(f'ds.ds_spend.worksheetId={worksheet_id}')
+  else:
+    warnings.warn(
+        'No media spend found in the spreadsheet. The report will'
+        ' display its demo data.'
+ ) + + if dc.MEDIA_ROI in spreadsheet.sheet_id_by_tab_name: + worksheet_id = spreadsheet.sheet_id_by_tab_name[dc.MEDIA_ROI] + params.append(f'ds.ds_roi.worksheetId={worksheet_id}') + else: + warnings.warn( + 'No media ROI found in the spreadsheet. The report will' + ' display its demo data.' + ) + + joined_params = '&'.join(params) + report_url = ( + 'https://lookerstudio.google.com/reporting/create?' + joined_params + ) + + return report_url diff --git a/scenarioplanner/linkingapi/url_generator_test.py b/scenarioplanner/linkingapi/url_generator_test.py new file mode 100644 index 000000000..d7271446e --- /dev/null +++ b/scenarioplanner/linkingapi/url_generator_test.py @@ -0,0 +1,297 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from unittest import mock + +from absl.testing import absltest +from scenarioplanner.converters import sheets +from scenarioplanner.converters.dataframe import constants as dc +from scenarioplanner.linkingapi import constants +from scenarioplanner.linkingapi import url_generator + + +_SPREADSHEET_URL = "https://docs.google.com/spreadsheets/d/test_id" +_ENCODED_SPREADSHEET_URL = ( + "https%3A%2F%2Fdocs.google.com%2Fspreadsheets%2Fd%2Ftest_id" +) +_SPREADSHEET_ID = "test_id" + +_REPORT_TEMPLATE_ID = "report_template_id" +_COMMUNITY_CONNECTOR_NAME = "community" +_COMMUNITY_CONNECTOR_ID = "cc_id" +_SHEETS_CONNECTOR_NAME = "google_sheets" +_GA4_MEASUREMENT_ID = "ga4_measurement_id" + + +class UrlGeneratorTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.base_spreadsheet = sheets.Spreadsheet( + url=_SPREADSHEET_URL, + id=_SPREADSHEET_ID, + sheet_id_by_tab_name={}, + ) + self.base_url = ( + "https://lookerstudio.google.com/reporting/create?" 
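+        # `create_report_url` always appends the report template ID and GA4
+        # measurement ID first, so every expected URL begins with this prefix.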
+ + f"c.reportId={_REPORT_TEMPLATE_ID}&" + + f"r.measurementId={_GA4_MEASUREMENT_ID}&" + ) + + self.enter_context( + mock.patch.object( + constants, + "REPORT_TEMPLATE_ID", + new=_REPORT_TEMPLATE_ID, + ) + ) + self.enter_context( + mock.patch.object( + constants, + "COMMUNITY_CONNECTOR_NAME", + new=_COMMUNITY_CONNECTOR_NAME, + ) + ) + self.enter_context( + mock.patch.object( + constants, + "COMMUNITY_CONNECTOR_ID", + new=_COMMUNITY_CONNECTOR_ID, + ) + ) + self.enter_context( + mock.patch.object( + constants, + "SHEETS_CONNECTOR_NAME", + new=_SHEETS_CONNECTOR_NAME, + ) + ) + self.enter_context( + mock.patch.object( + constants, + "GA4_MEASUREMENT_ID", + new=_GA4_MEASUREMENT_ID, + ) + ) + + def test_create_report_url_when_empty_spreadsheet(self): + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}" + ) + + self.assertEqual( + url_generator.create_report_url(self.base_spreadsheet), expected + ) + + def test_create_report_url_with_community_connector(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.OPTIMIZATION_SPECS: 1, + dc.OPTIMIZATION_RESULTS: 2, + dc.OPTIMIZATION_RESPONSE_CURVES: 3, + dc.RF_OPTIMIZATION_SPECS: 4, + dc.RF_OPTIMIZATION_RESULTS: 5, + }, + ) + expected = ( + self.base_url + + f"ds.dscc.connector={_COMMUNITY_CONNECTOR_NAME}&" + + f"ds.dscc.connectorId={_COMMUNITY_CONNECTOR_ID}&" + + f"ds.dscc.spreadsheetUrl={_ENCODED_SPREADSHEET_URL}&" + + "ds.dscc.dataSharingOptIn=false&" + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_with_community_connector_and_data_sharing_opt_in( + self, + ): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.OPTIMIZATION_SPECS: 1, + dc.OPTIMIZATION_RESULTS: 2, + dc.OPTIMIZATION_RESPONSE_CURVES: 3, + dc.RF_OPTIMIZATION_SPECS: 4, + dc.RF_OPTIMIZATION_RESULTS: 5, + }, + ) + expected = ( + self.base_url + + f"ds.dscc.connector={_COMMUNITY_CONNECTOR_NAME}&" + + f"ds.dscc.connectorId={_COMMUNITY_CONNECTOR_ID}&" + + f"ds.dscc.spreadsheetUrl={_ENCODED_SPREADSHEET_URL}&" + + "ds.dscc.dataSharingOptIn=true&" + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}" + ) + + self.assertEqual( + url_generator.create_report_url(spreadsheet, data_sharing_opt_in=True), + expected, + ) + + def test_create_report_url_warns_when_no_community_connector(self): + with self.assertWarnsRegex( + UserWarning, + "No optimization specs found in the spreadsheet. 
The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + def test_create_report_url_with_model_fit(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.MODEL_FIT: 1, + }, + ) + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}&" + + "ds.ds_model_fit.worksheetId=1" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_warns_when_no_model_fit(self): + with self.assertWarnsRegex( + UserWarning, + "No model fit found in the spreadsheet. The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + def test_create_report_url_with_model_diagnostics(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.MODEL_DIAGNOSTICS: 1, + }, + ) + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}&" + + "ds.ds_model_diag.worksheetId=1" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_warns_when_no_model_diagnostics(self): + with self.assertWarnsRegex( + UserWarning, + "No model diagnostics found in the spreadsheet. The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + def test_create_report_url_with_media_outcome(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.MEDIA_OUTCOME: 1, + }, + ) + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}&" + + "ds.ds_outcome.worksheetId=1" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_warns_when_no_media_outcome(self): + with self.assertWarnsRegex( + UserWarning, + "No media outcome found in the spreadsheet. The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + def test_create_report_url_with_media_spend(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.MEDIA_SPEND: 1, + }, + ) + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}&" + + "ds.ds_spend.worksheetId=1" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_warns_when_no_media_spend(self): + with self.assertWarnsRegex( + UserWarning, + "No media spend found in the spreadsheet. 
The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + def test_create_report_url_with_media_roi(self): + spreadsheet = dataclasses.replace( + self.base_spreadsheet, + sheet_id_by_tab_name={ + dc.MEDIA_ROI: 1, + }, + ) + expected = ( + self.base_url + + "ds.*.refreshFields=false&" + + "ds.*.keepDatasourceName=true&" + + f"ds.*.connector={_SHEETS_CONNECTOR_NAME}&" + + f"ds.*.spreadsheetId={_SPREADSHEET_ID}&" + + "ds.ds_roi.worksheetId=1" + ) + + self.assertEqual(url_generator.create_report_url(spreadsheet), expected) + + def test_create_report_url_warns_when_no_media_roi(self): + with self.assertWarnsRegex( + UserWarning, + "No media ROI found in the spreadsheet. The report will" + " display its demo data.", + ): + url_generator.create_report_url(self.base_spreadsheet) + + +if __name__ == "__main__": + absltest.main() diff --git a/scenarioplanner/mmm_ui_proto_generator.py b/scenarioplanner/mmm_ui_proto_generator.py new file mode 100644 index 000000000..aebac7739 --- /dev/null +++ b/scenarioplanner/mmm_ui_proto_generator.py @@ -0,0 +1,354 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates an `Mmm` (Marketing Mix Model) proto for Meridian UI. + +The MMM proto schema contains parts collected from the core model as well as +analysis results from trained model processors. 
+""" + +import abc +from collections.abc import Collection, Sequence +import dataclasses +import datetime +from typing import TypeVar +import uuid +import warnings + +from meridian.analysis import optimizer +from meridian.data import time_coordinates as tc +from meridian.model import model +from mmm.v1 import mmm_pb2 as mmm_pb +from scenarioplanner.converters.dataframe import constants as converter_constants +from schema import mmm_proto_generator +from schema.processors import budget_optimization_processor as bop +from schema.processors import marketing_processor +from schema.processors import model_fit_processor +from schema.processors import model_processor +from schema.processors import reach_frequency_optimization_processor as rfop +from schema.utils import date_range_bucketing + + +__all__ = [ + "MmmUiProtoGenerator", + "create_mmm_ui_data_proto", + "create_tag", +] + + +_ALLOWED_SPEC_TYPES_FOR_UI = frozenset({ + model_fit_processor.ModelFitSpec, + marketing_processor.MarketingAnalysisSpec, + bop.BudgetOptimizationSpec, +}) + +_SPEC_TYPES_CREATE_SUBSPECS = frozenset({ + marketing_processor.MarketingAnalysisSpec, + bop.BudgetOptimizationSpec, + rfop.ReachFrequencyOptimizationSpec, +}) + +_DATE_RANGE_GENERATORS = frozenset({ + date_range_bucketing.MonthlyDateRangeGenerator, + date_range_bucketing.QuarterlyDateRangeGenerator, + date_range_bucketing.YearlyDateRangeGenerator, +}) + +SpecType = TypeVar("SpecType", bound=model_processor.Spec) +DatedSpecType = TypeVar("DatedSpecType", bound=model_processor.DatedSpec) +OptimizationSpecType = TypeVar( + "OptimizationSpecType", bound=model_processor.OptimizationSpec +) + +_DERIVED_RF_OPT_NAME_PREFIX = "derived RF optimization from " +_DERIVED_RF_OPT_GRID_NAME_PREFIX = "derived_from_" + + +class MmmUiProtoGenerator: + """Creates `Mmm` proto for the Meridian Scenario Planner UI (Looker Studio). + + Currently, it only accepts specs for Model Fit, Marketing Analysis, and Budget + Optimization, but not stand-alone Reach Frequency Optimization specs. + Reach Frequency Optimization spec will be derived from the Budget Optimization + spec; this is done so that we can structurally pair them. + + Attributes: + mmm: A trained Meridian model. A trained model has its posterior + distributions already sampled. + specs: A sequence of specs that specify the analyses to run on the model. + model_id: An optional model identifier. + time_breakdown_generators: A list of generators that break down the given + specs by automatically generated time buckets. Currently, this time period + breakdown is only done on Marketing Analysis specs and Budget Optimization + specs. All other specs are processed in their original forms. The set of + default bucketers break down sub-specs with the following time periods: + [All (original spec's time period), Yearly, Quarterly, Monthly] + """ + + def __init__( + self, + mmm: model.Meridian, + specs: Sequence[SpecType], + model_id: str = "", + time_breakdown_generators: Collection[ + type[date_range_bucketing.DateRangeBucketer] + ] = _DATE_RANGE_GENERATORS, + ): + self._mmm = mmm + self._input_specs = specs + self._model_id = model_id + self._time_breakdown_generators = time_breakdown_generators + + @property + def _time_coordinates(self) -> tc.TimeCoordinates: + return self._mmm.input_data.time_coordinates + + def __call__(self) -> mmm_pb.Mmm: + """Creates `Mmm` proto for the Meridian Scenario Planner UI (Looker Studio). + + Returns: + A proto containing the model kernel at rest and its analysis results given + user specs. 
+    """
+    seen_group_ids = set()
+
+    copy_specs = []
+    for spec in self._input_specs:
+      if not any(isinstance(spec, t) for t in _ALLOWED_SPEC_TYPES_FOR_UI):
+        raise ValueError(f"Unsupported spec type: {spec.__class__.__name__}")
+
+      if isinstance(spec, bop.BudgetOptimizationSpec):
+        group_id = spec.group_id
+        if not group_id:
+          group_id = str(uuid.uuid4())
+          copy_specs.append(dataclasses.replace(spec, group_id=group_id))
+        else:
+          if group_id in seen_group_ids:
+            raise ValueError(
+                f"Duplicate group ID found: {group_id}. Please provide a unique"
+                " group ID for each Budget Optimization spec."
+            )
+          seen_group_ids.add(group_id)
+          copy_specs.append(spec)
+
+        # If there are RF channels, derive an RF optimization spec from the
+        # Budget Optimization spec.
+        if self._mmm.input_data.rf_channel is not None:
+          copy_specs.append(
+              self._derive_rf_opt_spec_from_budget_opt_spec(copy_specs[-1])
+          )
+      else:
+        copy_specs.append(spec)
+
+    sub_specs = []
+    for spec in copy_specs:
+      to_create_subspecs = self._time_breakdown_generators and any(
+          isinstance(spec, t) for t in _SPEC_TYPES_CREATE_SUBSPECS
+      )
+
+      if to_create_subspecs:
+        dates = self._enumerate_dates_open_end(spec)
+        sub_specs.extend(
+            _create_subspecs(spec, dates, self._time_breakdown_generators)
+        )
+      else:
+        sub_specs.append(spec)
+
+    return mmm_proto_generator.create_mmm_proto(
+        self._mmm,
+        sub_specs,
+        model_id=self._model_id,
+    )
+
+  def _derive_rf_opt_spec_from_budget_opt_spec(
+      self,
+      budget_opt_spec: bop.BudgetOptimizationSpec,
+  ) -> rfop.ReachFrequencyOptimizationSpec:
+    """Derives a ReachFrequencyOptimizationSpec from a BudgetOptimizationSpec."""
+    rf_opt_name = (
+        f"{_DERIVED_RF_OPT_NAME_PREFIX}{budget_opt_spec.optimization_name}"
+    )
+    rf_opt_grid_name = (
+        f"{_DERIVED_RF_OPT_GRID_NAME_PREFIX}{budget_opt_spec.optimization_name}"
+    )
+
+    return rfop.ReachFrequencyOptimizationSpec(
+        start_date=budget_opt_spec.start_date,
+        end_date=budget_opt_spec.end_date,
+        date_interval_tag=budget_opt_spec.date_interval_tag,
+        optimization_name=rf_opt_name,
+        grid_name=rf_opt_grid_name,
+        group_id=budget_opt_spec.group_id,
+        confidence_level=budget_opt_spec.confidence_level,
+    )
+
+  def _enumerate_dates_open_end(
+      self, spec: DatedSpecType
+  ) -> list[datetime.date]:
+    """Enumerates date points with an open end date.
+
+    The date points are enumerated from the data's time coordinates based on
+    the spec's start and end dates. The last date point is an exclusive end
+    date: the spec's end date if one is specified, otherwise one time interval
+    past the last selected date.
+
+    Args:
+      spec: A dated spec.
+
+    Returns:
+      A list of date points.
+    """
+    inclusive_date_strs = spec.resolver(
+        self._mmm.input_data.time_coordinates
+    ).resolve_to_enumerated_selected_times()
+
+    if inclusive_date_strs is None:
+      dates = self._time_coordinates.all_dates
+    else:
+      dates = [tc.normalize_date(date_str) for date_str in inclusive_date_strs]
+
+    # If the end date is not specified, compute the exclusive end date based on
+    # the last date in the time coordinates.
+    exclusive_end_date = spec.end_date or dates[-1] + datetime.timedelta(
+        days=self._time_coordinates.interval_days
+    )
+
+    dates.append(exclusive_end_date)
+
+    return dates
+
+
+def create_mmm_ui_data_proto(
+    mmm: model.Meridian,
+    specs: Sequence[SpecType],
+    model_id: str = "",
+    time_breakdown_generators: Collection[
+        type[date_range_bucketing.DateRangeBucketer]
+    ] = _DATE_RANGE_GENERATORS,
+) -> mmm_pb.Mmm:
+  """Creates `Mmm` proto for the Meridian Scenario Planner UI (Looker Studio).
+ + Currently, it only accepts specs for Model Fit, Marketing Analysis, and Budget + Optimization, but not stand-alone Reach Frequency Optimization specs. + Reach Frequency Optimization spec will be derived from the Budget Optimization + spec; this is done so that we can structurally pair them. + + Args: + mmm: A trained Meridian model. A trained model has its posterior + distributions already sampled. + specs: A sequence of specs that specify the analyses to run on the model. + model_id: An optional model identifier. + time_breakdown_generators: A list of generators that break down the given + specs by automatically generated time buckets. Currently, this time period + breakdown is only done on Marketing Analysis specs and Budget Optimization + specs. All other specs are processed in their original forms. The set of + default bucketers break down sub-specs with the following time periods: + [All (original spec's time period), Yearly, Quarterly, Monthly] + + Returns: + A proto containing the model kernel at rest and its analysis results given + user specs. + """ + return MmmUiProtoGenerator( + mmm, + specs, + model_id, + time_breakdown_generators, + )() + + +def create_tag( + generator_class: type[abc.ABC], start_date: datetime.date +) -> str: + """Creates a human-readable tag for a spec.""" + if generator_class == date_range_bucketing.YearlyDateRangeGenerator: + return f"Y{start_date.year}" + elif generator_class == date_range_bucketing.QuarterlyDateRangeGenerator: + return f"Y{start_date.year} Q{(start_date.month - 1) // 3 + 1}" + elif generator_class == date_range_bucketing.MonthlyDateRangeGenerator: + return f"Y{start_date.year} {start_date.strftime('%b')}" + else: + raise ValueError(f"Unsupported generator class: {generator_class}") + + +def _normalize_optimization_spec_time_info( + spec: OptimizationSpecType, + date_interval_tag: str, +) -> OptimizationSpecType: + """Adds time info to an optimization spec.""" + formatted_date_interval_tag = date_interval_tag.replace(r" ", "_") + return dataclasses.replace( + spec, + group_id=f"{spec.group_id}:{date_interval_tag}", + optimization_name=f"{spec.optimization_name} for {date_interval_tag}", + grid_name=f"{spec.grid_name}_{formatted_date_interval_tag}", + ) + + +def _create_subspecs( + spec: DatedSpecType, + date_range: list[datetime.date], + time_breakdown_generators: Collection[ + type[date_range_bucketing.DateRangeBucketer] + ], +) -> list[DatedSpecType]: + """Breaks down a spec into sub-specs for each time bucket.""" + specs = [] + + all_period_spec = dataclasses.replace( + spec, + date_interval_tag=converter_constants.ANALYSIS_TAG_ALL, + ) + if isinstance(all_period_spec, model_processor.OptimizationSpec): + all_period_spec = _normalize_optimization_spec_time_info( + all_period_spec, converter_constants.ANALYSIS_TAG_ALL + ) + specs.append(all_period_spec) + + for generator_class in time_breakdown_generators: + generator = generator_class(date_range) # pytype: disable=not-instantiable + date_intervals = generator.generate_date_intervals() + for start_date, end_date in date_intervals: + date_interval_tag = create_tag(generator_class, start_date) + new_spec = dataclasses.replace( + spec, + start_date=start_date, + end_date=end_date, + date_interval_tag=date_interval_tag, + ) + + if isinstance(new_spec, model_processor.OptimizationSpec): + new_spec = _normalize_optimization_spec_time_info( + new_spec, date_interval_tag + ) + + if ( + isinstance(new_spec, bop.BudgetOptimizationSpec) + and isinstance(new_spec.scenario, 
optimizer.FixedBudgetScenario) + and new_spec.scenario.total_budget is not None + ): + # TODO: The budget amount should be adjusted based on the + # budget specified in the `all_period_spec` and the historical spend + # at the time period. + new_spec = dataclasses.replace( + new_spec, + scenario=optimizer.FixedBudgetScenario(total_budget=None), + ) + warnings.warn( + "Using historical spend for budget optimization spec at the" + f" period of {date_interval_tag}", + ) + + specs.append(new_spec) + + return specs diff --git a/schema/__init__.py b/schema/__init__.py new file mode 100644 index 000000000..e01cfc373 --- /dev/null +++ b/schema/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module containing MMM schema library.""" + +from schema import mmm_proto_generator +from schema import model_consumer +from schema import processors +from schema import serde +from schema import utils diff --git a/schema/mmm_proto_generator.py b/schema/mmm_proto_generator.py new file mode 100644 index 000000000..e8d010efc --- /dev/null +++ b/schema/mmm_proto_generator.py @@ -0,0 +1,71 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates an `Mmm` (Marketing Mix Model) proto for Meridian. + +The MMM proto schema contains parts collected from the core model as well as +analysis results from trained model processors. 
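+
+Example usage (a minimal sketch; assumes `meridian` is a trained
+`model.Meridian` instance):
+
+```python
+from schema import mmm_proto_generator
+from schema.processors import model_fit_processor
+
+mmm_proto = mmm_proto_generator.create_mmm_proto(
+    meridian,
+    [model_fit_processor.ModelFitSpec()],
+    model_id="my-model",
+)
+```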
+""" + +from collections.abc import Sequence +from typing import TypeVar + +from meridian.model import model +from mmm.v1 import mmm_pb2 as mmm_pb +from schema import model_consumer +from schema.processors import budget_optimization_processor +from schema.processors import marketing_processor +from schema.processors import model_fit_processor +from schema.processors import model_processor +from schema.processors import reach_frequency_optimization_processor + + +__all__ = [ + "create_mmm_proto", +] + + +_TYPES = ( + model_fit_processor.ModelFitProcessor, + marketing_processor.MarketingProcessor, + budget_optimization_processor.BudgetOptimizationProcessor, + reach_frequency_optimization_processor.ReachFrequencyOptimizationProcessor, +) + +SpecType = TypeVar("SpecType", bound=model_processor.Spec) +DatedSpecType = TypeVar("DatedSpecType", bound=model_processor.DatedSpec) +OptimizationSpecType = TypeVar( + "OptimizationSpecType", bound=model_processor.OptimizationSpec +) + + +def create_mmm_proto( + mmm: model.Meridian, + specs: Sequence[SpecType], + model_id: str = "", +) -> mmm_pb.Mmm: + """Creates a model schema and analyses for various time buckets. + + Args: + mmm: A trained Meridian model. A trained model has its posterior + distributions already sampled. + specs: A sequence of specs that specify the analyses to run on the model. + model_id: An optional model identifier. + + Returns: + A proto containing the model kernel at rest and its analysis results given + user specs. + """ + consumer = model_consumer.ModelConsumer(_TYPES) + return consumer(mmm, specs, model_id) diff --git a/schema/mmm_proto_generator_test.py b/schema/mmm_proto_generator_test.py new file mode 100644 index 000000000..e4882bdf0 --- /dev/null +++ b/schema/mmm_proto_generator_test.py @@ -0,0 +1,153 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from unittest import mock + +from absl.testing import absltest +from meridian.analysis import optimizer +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from schema import mmm_proto_generator +from schema import test_data as td +from schema.processors import budget_optimization_processor +from schema.processors import marketing_processor +from schema.processors import model_fit_processor +from schema.processors import model_processor +from schema.processors import reach_frequency_optimization_processor as rf_opt_processor +from schema.serde import meridian_serde + +from tensorflow.python.util.protobuf import compare + + +_STUBBED_PROCESSORS = [ + td.FakeModelFitProcessor, + td.FakeBudgetOptimizationProcessor, + td.FakeReachFrequencyOptimizationProcessor, + td.FakeMarketingProcessor, +] + + +class MmmProtoGeneratorTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_mmm = mock.MagicMock() + + # Patch the trained model class. 
+ self.mock_trained_model = self.enter_context( + mock.patch.object(model_processor, 'TrainedModel', autospec=True) + )(mmm=self.mock_mmm) + self.enter_context( + mock.patch.object( + meridian_serde.MeridianSerde, + 'serialize', + autospec=True, + return_value=kernel_pb.MmmKernel(), + ) + ) + + self.enter_context( + mock.patch.object( + mmm_proto_generator, + '_TYPES', + new=_STUBBED_PROCESSORS, + ) + ) + + def test_create_mmm_proto_populates_model_kernel(self): + output = mmm_proto_generator.create_mmm_proto(self.mock_mmm, []) + self.assertTrue(output.HasField('mmm_kernel')) + + def test_create_mmm_proto_populates_each_processor_output(self): + output = mmm_proto_generator.create_mmm_proto( + self.mock_mmm, + [ + model_fit_processor.ModelFitSpec(), + budget_optimization_processor.BudgetOptimizationSpec( + start_date=datetime.date(2023, 1, 2), + end_date=datetime.date(2023, 2, 6), + optimization_name='budget optimization', + grid_name='youtube_campaign', + scenario=optimizer.FixedBudgetScenario(1), + ), + rf_opt_processor.ReachFrequencyOptimizationSpec( + start_date=datetime.date(2023, 1, 2), + end_date=datetime.date(2023, 2, 6), + optimization_name='RF optimization', + grid_name='RF optimization grid', + ), + marketing_processor.MarketingAnalysisSpec( + start_date=datetime.date(2023, 1, 2), + end_date=datetime.date(2023, 2, 6), + media_summary_spec=marketing_processor.MediaSummarySpec(), + ), + ], + ) + with self.subTest('ModelFit'): + self.assertTrue(output.HasField('model_fit')) + with self.subTest('MarketingAnalysisList'): + self.assertTrue(output.HasField('marketing_analysis_list')) + with self.subTest('MarketingOptimization'): + self.assertTrue(output.HasField('marketing_optimization')) + with self.subTest('BudgetOptimization'): + self.assertTrue( + output.marketing_optimization.HasField('budget_optimization') + ) + with self.subTest('ReachFrequencyOptimization'): + self.assertTrue( + output.marketing_optimization.HasField('reach_frequency_optimization') + ) + + def test_create_mmm_proto_processes_multiple_specs_of_the_same_type(self): + expected_marketing_analysis_specs = [ + marketing_processor.MarketingAnalysisSpec( + start_date=datetime.date(2022, 11, 21), + end_date=datetime.date(2023, 1, 16), + media_summary_spec=marketing_processor.MediaSummarySpec(), + ), + marketing_processor.MarketingAnalysisSpec( + start_date=datetime.date(2023, 3, 27), + end_date=datetime.date(2023, 4, 10), + media_summary_spec=marketing_processor.MediaSummarySpec(), + ), + ] + output = mmm_proto_generator.create_mmm_proto( + self.mock_mmm, + [ + model_fit_processor.ModelFitSpec(), + marketing_processor.MarketingAnalysisSpec( + start_date=datetime.date(2022, 11, 21), + end_date=datetime.date(2023, 1, 16), + media_summary_spec=marketing_processor.MediaSummarySpec(), + ), + marketing_processor.MarketingAnalysisSpec( + start_date=datetime.date(2023, 3, 27), + end_date=datetime.date(2023, 4, 10), + media_summary_spec=marketing_processor.MediaSummarySpec(), + ), + ], + ) + self.assertTrue(output.HasField('model_fit')) + self.assertTrue(output.HasField('marketing_analysis_list')) + compare.assertProtoEqual( + self, + output.marketing_analysis_list, + td.FakeMarketingProcessor(self.mock_trained_model).execute( + expected_marketing_analysis_specs + ), + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/schema/model_consumer.py b/schema/model_consumer.py new file mode 100644 index 000000000..3f8a6b364 --- /dev/null +++ b/schema/model_consumer.py @@ -0,0 +1,133 @@ +# Copyright 2025 The Meridian 
Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Consumes a trained Meridian model and produces an `Mmm` proto. + +The `Mmm` proto contains parts collected from the core model as well as +analysis results from trained model processors. +""" + +from collections.abc import Mapping, Sequence +import functools +import inspect +from typing import Any, Generic, TypeAlias, TypeVar + +from meridian.model import model +from mmm.v1 import mmm_pb2 as mmm_pb +from schema.processors import model_kernel_processor +from schema.processors import model_processor + + +__all__ = [ + "ModelConsumer", +] + + +SpecType: TypeAlias = type[model_processor.Spec] +ProcType = TypeVar("ProcType", bound=type[model_processor.ModelProcessor]) + + +class ModelConsumer(Generic[ProcType]): + """Consumes a trained Meridian model and produces an `Mmm` proto. + + Attributes: + model_processors: A preset list of model processor types. + """ + + def __init__( + self, + model_processors_classes: Sequence[ProcType], + ): + self._model_processors_classes = model_processors_classes + + @functools.cached_property + def specs_to_processors_classes( + self, + ) -> dict[SpecType, ProcType]: + """Returns a mapping of spec types to their corresponding processor types. + + Raises: + ValueError: If multiple model processors are found for the same spec type. + """ + specs_to_processors_classes = {} + for processor_class in self._model_processors_classes: + if ( + specs_to_processors_classes.get(processor_class.spec_type()) + is not None + ): + raise ValueError( + "Multiple model processors found for spec type:" + f" {processor_class.spec_type()}" + ) + specs_to_processors_classes[processor_class.spec_type()] = processor_class + return specs_to_processors_classes + + def __call__( + self, + mmm: model.Meridian, + specs: Sequence[model_processor.Spec], + model_id: str = "", + ) -> mmm_pb.Mmm: + """Produces an `Mmm` schema for the model along with its analyses results. + + Args: + mmm: A trained Meridian model. A trained model has its posterior + distributions already sampled. + specs: A sequence of specs that specify the analyses to run on the model. + Specs of the same type will be grouped together and executed together by + the corresponding model processor. + model_id: An optional model identifier. + + Returns: + A proto containing the model kernel at rest and its analysis results. + """ + + # Group specs by their type. + specs_by_type = {} + for spec in specs: + specs_by_type.setdefault(spec.__class__, []).append(spec) + + tmodel = model_processor.TrainedModel(mmm) + processor_params = { + "trained_model": tmodel, + } + + output = mmm_pb.Mmm() + # Attach the model kernel to the Mmm proto. + model_kernel_processor.ModelKernelProcessor(mmm, model_id)(output) + + # Perform analysis or optimization. + for spec_type, specs in specs_by_type.items(): + processor_type = self.specs_to_processors_classes[spec_type] + processor = _create_processor(processor_type, processor_params) + # Attach the output of the processor to the output proto. 
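+      # A processor's `__call__` is assumed to run `execute(specs)` and then
+      # `_set_output(output, result)`, grafting the result onto the matching
+      # `Mmm` field (see the Foo/Bar processors in model_consumer_test.py).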
+ processor(specs, output) + + return output + + +def _create_processor( + processor_type: ProcType, + processor_params: Mapping[str, Any], +) -> model_processor.ModelProcessor: + """Creates a processor of the given type with a subset of the given params.""" + # Clone the given parameters dict first. + params = dict(processor_params) + # Remove any parameters that are not in the processor's constructor signature. + sig = inspect.signature(processor_type.__init__) + if not any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + for missing in params.keys() - sig.parameters.keys(): + del params[missing] + # Finally, construct the concrete processor. + return processor_type(**params) diff --git a/schema/model_consumer_test.py b/schema/model_consumer_test.py new file mode 100644 index 000000000..6ec80c913 --- /dev/null +++ b/schema/model_consumer_test.py @@ -0,0 +1,187 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +import functools +from unittest import mock + +from absl.testing import absltest +from mmm.v1 import mmm_pb2 as mmm_pb +from mmm.v1.fit import model_fit_pb2 as fit_pb +from mmm.v1.marketing.optimization import marketing_optimization_pb2 as opt_pb +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from schema import model_consumer +from schema.processors import model_processor +from schema.serde import meridian_serde + + +class FooSpec(model_processor.Spec): + + def validate(self): + pass + + +class FooProcessor(model_processor.ModelProcessor[FooSpec, fit_pb.ModelFit]): + + def __init__(self, trained_model: model_processor.TrainedModel): + self._trained_model = trained_model + + @classmethod + def spec_type(cls): + return FooSpec + + @classmethod + def output_type(cls): + return fit_pb.ModelFit + + def execute(self, specs: Sequence[FooSpec]) -> fit_pb.ModelFit: + return fit_pb.ModelFit() + + def _set_output(self, output: mmm_pb.Mmm, result: fit_pb.ModelFit): + output.model_fit.CopyFrom(result) + + +class BarSpec(model_processor.Spec): + + def validate(self): + pass + + +class BarProcessor( + model_processor.ModelProcessor[BarSpec, opt_pb.MarketingOptimization] +): + + def __init__( + self, + trained_model: model_processor.TrainedModel, + ): + self._trained_model = trained_model + + @classmethod + def spec_type(cls): + return BarSpec + + @classmethod + def output_type(cls): + return opt_pb.MarketingOptimization + + def execute(self, specs: Sequence[BarSpec]) -> opt_pb.MarketingOptimization: + return opt_pb.MarketingOptimization() + + def _set_output( + self, output: mmm_pb.Mmm, result: opt_pb.MarketingOptimization + ): + output.marketing_optimization.CopyFrom(result) + + +class ModelConsumerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_mmm = mock.MagicMock() + + # Patch the trained model class. 
+ self.mock_trained_model = self.enter_context( + mock.patch.object(model_processor, 'TrainedModel', autospec=True) + )(mmm=self.mock_mmm) + self.consumer = model_consumer.ModelConsumer([FooProcessor, BarProcessor]) + self.enter_context( + mock.patch.object( + meridian_serde.MeridianSerde, + 'serialize', + autospec=True, + return_value=kernel_pb.MmmKernel(), + ) + ) + + def test_specs_to_processors(self): + self.assertEqual( + self.consumer.specs_to_processors_classes, + {FooSpec: FooProcessor, BarSpec: BarProcessor}, + ) + + def test_specs_to_processors_error_on_duplicate_spec_types(self): + class DuplicateProcessor( + model_processor.ModelProcessor[FooSpec, fit_pb.ModelFit] + ): + + @classmethod + def spec_type(cls): + return FooSpec + + @classmethod + def output_type(cls): + return fit_pb.ModelFit + + def execute(self, specs: Sequence[FooSpec]) -> fit_pb.ModelFit: + return fit_pb.ModelFit() + + def _set_output(self, output: mmm_pb.Mmm, result: fit_pb.ModelFit): + output.model_fit.CopyFrom(result) + + with self.assertRaises(ValueError): + _ = model_consumer.ModelConsumer( + [FooProcessor, BarProcessor, DuplicateProcessor] + ).specs_to_processors_classes + + def test_consumer_call_dispatches_to_processors(self): + # This context dict will be used to verify that the correct processors are + # called. + context = { + 'foo': False, # if FooProcessor.execute is called. + 'bar': False, # if BarProcessor.execute is called. + } + + def _patch_foo_execute( + slf, specs: Sequence[FooSpec], context + ) -> fit_pb.ModelFit: + self.assertLen(specs, 2) + self.assertTrue(all([isinstance(spec, FooSpec) for spec in specs])) + self.assertIs(slf._trained_model, self.mock_trained_model) + context['foo'] = True + return fit_pb.ModelFit() + + FooProcessor.execute = functools.partialmethod( + _patch_foo_execute, context=context + ) + + def _patch_bar_execute( + slf, specs: Sequence[BarSpec], context + ) -> opt_pb.MarketingOptimization: + self.assertLen(specs, 1) + self.assertTrue(all([isinstance(spec, BarSpec) for spec in specs])) + self.assertIs(slf._trained_model, self.mock_trained_model) + context['bar'] = True + return opt_pb.MarketingOptimization() + + BarProcessor.execute = functools.partialmethod( + _patch_bar_execute, context=context + ) + + self.consumer = model_consumer.ModelConsumer([FooProcessor, BarProcessor]) + + # Calling the model consumer should execute both FooProcessor and + # BarProcessor: FooProcessor should be given two FooSpecs and BarProcessor + # should be given one BarSpec. + output = self.consumer(self.mock_mmm, [FooSpec(), FooSpec(), BarSpec()]) + + self.assertTrue(output.HasField('mmm_kernel')) + self.assertTrue(output.HasField('model_fit')) + self.assertTrue(output.HasField('marketing_optimization')) + self.assertTrue(context['foo'], 'FooProcessor.execute was not called.') + self.assertTrue(context['bar'], 'BarProcessor.execute was not called.') + + +if __name__ == '__main__': + absltest.main() diff --git a/schema/processors/__init__.py b/schema/processors/__init__.py new file mode 100644 index 000000000..4157bf3ad --- /dev/null +++ b/schema/processors/__init__.py @@ -0,0 +1,77 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Meridian Model Processor Library.
+
+This package provides a collection of processors designed to operate on trained
+Meridian models. These processors facilitate various post-training tasks,
+including model analysis, insight generation, and budget optimization.
+
+The processors are built upon a common framework defined in the
+`model_processor` module, which establishes the base classes and interfaces for
+the built-in processors in this package, as well as for creating custom
+processors. Each processor typically takes a trained Meridian model object and
+additional specifications as input, producing structured output in protobuf
+format.
+
+These structured outputs can then be used to generate insights, visualizations,
+and other artifacts that help users understand and optimize their marketing
+strategy. For instance, the `schema.converters` package provides tools to
+flatten these outputs into tabular Google Sheets tables suitable for a Meridian
+Looker Studio dashboard's data sources.
+
+Available Processor Modules:
+
+- `model_processor`: Defines the abstract base classes `ModelProcessor` and
+  `Spec`, which serve as the foundation for all processors in this package.
+- `model_kernel_processor`: A processor to extract and serialize the core
+  components and parameters of the trained Meridian model.
+- `model_fit_processor`: Generates various goodness-of-fit statistics and
+  diagnostic metrics for the trained model.
+- `marketing_processor`: Performs marketing mix analysis, including
+  contribution analysis, response curves, and ROI calculations.
+- `budget_optimization_processor`: Provides tools for optimizing marketing
+  budgets based on the model's predictions to achieve specific goals.
+- `reach_frequency_optimization_processor`: Analyzes and optimizes based on
+  reach and frequency metrics, if applicable to the model structure.
+
+Each processor defines its own spec language. For instance, the budget
+optimization processor takes a `BudgetOptimizationSpec` object as input, which
+defines the constraints and parameters of the optimization problem a user
+wants to explore.
+
+A trained Meridian model is a requisite input for all processors. Generally, a
+`model_processor.TrainedModel` wrapper object is passed to each processor,
+along with its processor-specific specs. For example:
+
+```python
+# Assuming 'trained_model' is a loaded Meridian model object
+processor = model_fit_processor.ModelFitProcessor(trained_model)
+result = processor.execute([model_fit_processor.ModelFitSpec()])
+
+# `result` is a structured `ModelFit` proto that describes the model's
+# goodness-of-fit analysis.
+```
+
+For more details on each processor's API, refer to the documentation of the
+individual modules.
+""" + +from schema.processors import budget_optimization_processor +from schema.processors import common +from schema.processors import marketing_processor +from schema.processors import model_fit_processor +from schema.processors import model_kernel_processor +from schema.processors import model_processor +from schema.processors import reach_frequency_optimization_processor diff --git a/schema/processors/budget_optimization_processor.py b/schema/processors/budget_optimization_processor.py new file mode 100644 index 000000000..b225593e7 --- /dev/null +++ b/schema/processors/budget_optimization_processor.py @@ -0,0 +1,813 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a processor for budget optimization inference on a Meridian model. + +This module provides the `BudgetOptimizationProcessor` class, which is used to +perform marketing budget optimization based on a trained Meridian model. The +processor takes a trained model and a `BudgetOptimizationSpec` object, +which defines the optimization parameters, constraints, and scenarios. + +The optimization process aims to find the optimal allocation of budget across +different media channels to maximize a specified objective, such as Key +Performance Indicator (KPI) or Revenue, subject to various constraints. + +Key Features: + +- Supports both fixed and flexible budget scenarios. +- Allows setting channel-level budget constraints, either as absolute values + or relative to historical spend. +- Generates detailed optimization results, including optimal spends, expected + outcomes, and response curves. +- Outputs results in a structured protobuf format (`BudgetOptimization`). + +Key Classes: + +- `BudgetOptimizationSpec`: Dataclass to specify optimization parameters and + constraints. +- `BudgetOptimizationProcessor`: The main processor class to execute budget + optimization. + +Example Usage: + +1. **Fixed Budget Optimization:** + Optimize budget allocation for a fixed total budget, aiming to maximize KPI. 
+ + ```python + from schema.processors import budget_optimization_processor + from meridian.analysis import optimizer + from schema.processors import common + + # Assuming 'trained_model' is a loaded Meridian model object + + spec = budget_optimization_processor.BudgetOptimizationSpec( + optimization_name="fixed_budget_scenario_1", + scenario=optimizer.FixedBudgetScenario(total_budget=1000000), + kpi_type=common.KpiType.REVENUE, # Or common.KpiType.NON_REVENUE + # Optional: Add channel constraints + constraints=[ + budget_optimization_processor.ChannelConstraintRel( + channel_name="channel_a", + spend_constraint_lower=0.1, # Allow 10% decrease + spend_constraint_upper=0.5 # Allow 50% increase + ), + budget_optimization_processor.ChannelConstraintRel( + channel_name="channel_b", + spend_constraint_lower=0.0, # No decrease + spend_constraint_upper=1.0 # Allow 100% increase + ) + ], + include_response_curves=True, + ) + + processor = budget_optimization_processor.BudgetOptimizationProcessor( + trained_model + ) + # result is a `budget_pb.BudgetOptimization` proto + result = processor.execute([spec]) + ``` + +2. **Flexible Budget Optimization:** + Optimize budget to achieve a target Return on Investment (ROI). + + ```python + from schema.processors import budget_optimization_processor + from meridian.analysis import optimizer + from schema.processors import common + import meridian.constants as c + + # Assuming 'trained_model' is a loaded Meridian model object + + spec = budget_optimization_processor.BudgetOptimizationSpec( + optimization_name="flexible_roi_target", + scenario=optimizer.FlexibleBudgetScenario( + target_metric=c.ROI, + target_value=3.5 # Target ROI of 3.5 + ), + kpi_type=common.KpiType.REVENUE, + date_interval_tag="optimization_period", + # Skip response curves for faster computation. + include_response_curves=False, + ) + + processor = budget_optimization_processor.BudgetOptimizationProcessor( + trained_model + ) + result = processor.execute([spec]) + ``` + +Note: You can provide the processor with multiple specs. This would result in +a `BudgetOptimization` output with multiple results therein. +""" + +from collections.abc import Mapping, Sequence +import dataclasses +from typing import TypeAlias + +from meridian import constants as c +from meridian.analysis import analyzer +from meridian.analysis import optimizer +from meridian.data import time_coordinates as tc +from mmm.v1 import mmm_pb2 as pb +from mmm.v1.common import estimate_pb2 as estimate_pb +from mmm.v1.common import kpi_type_pb2 as kpi_type_pb +from mmm.v1.common import target_metric_pb2 as target_pb +from mmm.v1.marketing.analysis import marketing_analysis_pb2 as analysis_pb +from mmm.v1.marketing.analysis import media_analysis_pb2 as media_analysis_pb +from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb +from mmm.v1.marketing.analysis import response_curve_pb2 as response_curve_pb +from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb +from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb +from schema.processors import common +from schema.processors import model_processor +from schema.utils import time_record +import numpy as np +from typing_extensions import override +import xarray as xr + +__all__ = [ + 'BudgetOptimizationProcessor', + 'BudgetOptimizationSpec', + 'ChannelConstraintAbs', + 'ChannelConstraintRel', +] + + +# Default lower and upper bounds (as _relative_ ratios) for channel constraints. 
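+# Given the bound formulas on `ChannelConstraintRel` below, the default lower
+# ratio of 1 lets a channel's spend drop to zero ((1 - 1) * historical spend),
+# and the default upper ratio of 2 lets it grow to three times historical
+# spend ((1 + 2) * historical spend).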
+CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO = 1
+CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO = 2
+
+
+@dataclasses.dataclass(frozen=True)
+class ChannelConstraintAbs:
+  """A budget constraint on a channel.
+
+  Constraint attributes in this dataclass are absolute values. Useful for
+  representing resolved absolute constraint values in output spec metadata.
+
+  Attributes:
+    channel_name: The name of the channel.
+    abs_lowerbound: A simple absolute lower bound value for a channel's spend.
+    abs_upperbound: A simple absolute upper bound value for a channel's spend.
+  """
+
+  channel_name: str
+  abs_lowerbound: float
+  abs_upperbound: float
+
+  def to_proto(self) -> budget_pb.ChannelConstraint:
+    return budget_pb.ChannelConstraint(
+        channel_name=self.channel_name,
+        budget_constraint=constraints_pb.BudgetConstraint(
+            min_budget=self.abs_lowerbound,
+            max_budget=self.abs_upperbound,
+        ),
+    )
+
+
+@dataclasses.dataclass(frozen=True)
+class ChannelConstraintRel:
+  """A budget constraint on a channel.
+
+  Constraint attributes in this dataclass are relative ratios. Useful for user
+  input specs.
+
+  Attributes:
+    channel_name: The name of the channel.
+    spend_constraint_lower: The allowed decrease in a channel's spend,
+      expressed as a ratio of the channel's historical spend. The absolute
+      lower bound value is equal to
+      `(1 - spend_constraint_lower) * hist_channel_spend`. The value must be
+      in `[0, 1]`.
+    spend_constraint_upper: The allowed increase in a channel's spend,
+      expressed as a ratio of the channel's historical spend. The absolute
+      upper bound value is equal to
+      `(1 + spend_constraint_upper) * hist_channel_spend`. The value must be
+      non-negative.
+  """
+
+  channel_name: str
+  spend_constraint_lower: float
+  spend_constraint_upper: float
+
+  def __post_init__(self):
+    if self.spend_constraint_lower < 0:
+      raise ValueError('Spend constraint lower must be non-negative.')
+    if self.spend_constraint_lower > 1:
+      raise ValueError('Spend constraint lower must not be greater than 1.')
+    if self.spend_constraint_upper < 0:
+      raise ValueError('Spend constraint upper must be non-negative.')
+
+
+ChannelConstraint: TypeAlias = ChannelConstraintAbs | ChannelConstraintRel
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class BudgetOptimizationSpec(model_processor.OptimizationSpec):
+  """Spec dataclass for the marketing budget optimization processor.
+
+  This spec is used both as user input to inform the budget optimization
+  processor of its constraints and parameters, and as an output structure
+  that is serializable to a `BudgetOptimizationSpec` proto. The latter serves
+  as metadata embedded in a `BudgetOptimizationResult`.
+
+  Attributes:
+    objective: Always defined as KPI.
+    scenario: The optimization scenario (whether fixed or flexible).
+    constraints: Per-channel budget constraints. Defaults to relative
+      constraints `[1, 2]` for spend_constraint_lower and
+      spend_constraint_upper if not specified.
+    kpi_type: A `common.KpiType` enum denoting whether the optimized KPI is of
+      a `'revenue'` or `'non-revenue'` type.
+    grid: The optimization grid to use for the optimization. If None, a new
+      grid will be created within the optimizer.
+    include_response_curves: Whether to include response curves in the output.
+      Setting this to `False` improves performance if only the optimization
+      result is needed.
+    new_data: The new data to use for the optimization. If None, the training
+      data will be used.
+ """ + + scenario: optimizer.FixedBudgetScenario | optimizer.FlexibleBudgetScenario = ( + dataclasses.field(default_factory=optimizer.FixedBudgetScenario) + ) + constraints: Sequence[ChannelConstraint] = dataclasses.field( + default_factory=list + ) + kpi_type: common.KpiType = common.KpiType.REVENUE + grid: optimizer.OptimizationGrid | None = None + include_response_curves: bool = True + new_data: analyzer.DataTensors | None = None + + @property + def objective(self) -> common.TargetMetric: + """A Meridian budget optimization objective is always KPI.""" + return common.TargetMetric.KPI + + @override + def validate(self): + super().validate() + if (self.new_data is not None) and (self.new_data.time is None): + raise ValueError('`time` must be provided in `new_data`.') + + # TODO: Populate `new_marketing_data`. + def to_proto(self) -> budget_pb.BudgetOptimizationSpec: + # When invoked as an output proto, the spec should have been fully resolved. + if self.start_date is None or self.end_date is None: + raise ValueError( + 'Start and end dates must be resolved before this spec can be' + ' serialized.' + ) + + proto = budget_pb.BudgetOptimizationSpec( + date_interval=time_record.create_date_interval_pb( + self.start_date, self.end_date, tag=self.date_interval_tag + ), + objective=self.objective.value, + kpi_type=( + kpi_type_pb.KpiType.REVENUE + if self.kpi_type == common.KpiType.REVENUE + else kpi_type_pb.KpiType.NON_REVENUE + ), + ) + + match self.scenario: + case optimizer.FixedBudgetScenario(total_budget): + if total_budget is None: + raise ValueError( + 'Total budget must be resolved before this spec can be serialized' + ) + proto.fixed_budget_scenario.total_budget = total_budget + case optimizer.FlexibleBudgetScenario(target_metric, target_value): + proto.flexible_budget_scenario.target_metric_constraints.append( + constraints_pb.TargetMetricConstraint( + target_metric=_target_metric_to_proto(target_metric), + target_value=target_value, + ) + ) + case _: + raise ValueError('Unsupported scenario type.') + + for channel_constraint in self.constraints: + # When invoked as an output proto, the spec's constraints must have been + # resolved to absolute values. + if not isinstance(channel_constraint, ChannelConstraintAbs): + raise ValueError( + 'Channel constraints must be resolved to absolute values before' + ' this spec can be serialized.' 
+        )
+
+      proto.channel_constraints.append(
+          budget_pb.ChannelConstraint(
+              channel_name=channel_constraint.channel_name,
+              budget_constraint=constraints_pb.BudgetConstraint(
+                  min_budget=channel_constraint.abs_lowerbound,
+                  max_budget=channel_constraint.abs_upperbound,
+              ),
+          )
+      )
+
+    return proto
+
+
+class BudgetOptimizationProcessor(
+    model_processor.ModelProcessor[
+        BudgetOptimizationSpec, budget_pb.BudgetOptimization
+    ],
+):
+  """A processor for marketing budget optimization."""
+
+  def __init__(
+      self,
+      trained_model: model_processor.ModelType,
+  ):
+    self._trained_model = model_processor.ensure_trained_model(trained_model)
+    self._internal_analyzer = self._trained_model.internal_analyzer
+    self._internal_optimizer = self._trained_model.internal_optimizer
+
+  @classmethod
+  def spec_type(cls) -> type[BudgetOptimizationSpec]:
+    return BudgetOptimizationSpec
+
+  @classmethod
+  def output_type(cls) -> type[budget_pb.BudgetOptimization]:
+    return budget_pb.BudgetOptimization
+
+  def _set_output(self, output: pb.Mmm, result: budget_pb.BudgetOptimization):
+    output.marketing_optimization.budget_optimization.CopyFrom(result)
+
+  def execute(
+      self, specs: Sequence[BudgetOptimizationSpec]
+  ) -> budget_pb.BudgetOptimization:
+    output = budget_pb.BudgetOptimization()
+
+    group_ids = [spec.group_id for spec in specs if spec.group_id]
+    if len(set(group_ids)) != len(group_ids):
+      raise ValueError(
+          'Specified group_id must be unique among the given group of specs.'
+      )
+
+    # For each given spec:
+    # 1. Run optimize, which computes channel outcomes and their optimal
+    #    spends.
+    # 2. Convert the resulting optimization grid into an incremental
+    #    spend/outcome grid.
+    # 3. Compile the final BudgetOptimization proto.
+    for spec in specs:
+      kwargs = build_scenario_kwargs(spec.scenario)
+      constraints_kwargs = build_constraints_kwargs(
+          spec.constraints,
+          self._trained_model.mmm.input_data.get_all_paid_channels(),
+      )
+      kwargs.update(constraints_kwargs)
+      if spec.new_data is not None and spec.new_data.time is not None:
+        time_coords = tc.TimeCoordinates.from_dates(
+            [s.decode() for s in np.asarray(spec.new_data.time)]
+        )
+      else:
+        time_coords = self._trained_model.time_coordinates
+      resolver = spec.resolver(time_coords)
+      start_date, end_date = resolver.to_closed_date_interval_tuple()
+
+      # Note that `optimize()` maximizes KPI if the input data is non-revenue
+      # and the user selected `use_kpi=True`. Otherwise, it maximizes revenue.
+      opt_result = self._internal_optimizer.optimize(
+          start_date=start_date,
+          end_date=end_date,
+          fixed_budget=isinstance(spec.scenario, optimizer.FixedBudgetScenario),
+          confidence_level=spec.confidence_level,
+          use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE),
+          optimization_grid=spec.grid,
+          new_data=spec.new_data,
+          **kwargs,
+      )
+
+      output.results.append(
+          self._to_budget_optimization_result(
+              spec, opt_result, resolver, **constraints_kwargs
+          )
+      )
+
+    return output
+
+  def _to_budget_optimization_result(
+      self,
+      spec: BudgetOptimizationSpec,
+      opt_result: optimizer.OptimizationResults,
+      resolver: model_processor.DatedSpecResolver,
+      spend_constraint_lower: Sequence[float],
+      spend_constraint_upper: Sequence[float],
+  ) -> budget_pb.BudgetOptimizationResult:
+    """Converts an optimizer result to a BudgetOptimizationResult proto.
+
+    Args:
+      spec: The spec used to generate the optimization result.
+      opt_result: The result of the optimization.
+      resolver: A DatedSpecResolver instance.
+ spend_constraint_lower: A sequence of lower bound constraints for each + channel, in relative terms. + spend_constraint_upper: A sequence of upper bound constraints for each + channel, in relative terms. + + Returns: + A BudgetOptimizationResult proto. + """ + # Copy the current spec, and resolve its date interval. + start, end = resolver.resolve_to_date_interval_open_end() + + # Resolve the given (input) spec to an (output) spec: the latter features + # dates and absolute channel constraints resolution. + spec = dataclasses.replace( + spec, + start_date=start, + end_date=end, + constraints=_get_channel_constraints_abs( + opt_result=opt_result, + constraint_lower=spend_constraint_lower, + constraint_upper=spend_constraint_upper, + ), + ) + + # If the spec is a fixed budget scenario, but the total budget is not + # specified, then set it to the budget amount used in the optimization. + resolve_historical_budget = ( + isinstance(spec.scenario, optimizer.FixedBudgetScenario) + and spec.scenario.total_budget is None + ) + if resolve_historical_budget: + spec = dataclasses.replace( + spec, + scenario=optimizer.FixedBudgetScenario( + total_budget=opt_result.optimized_data.attrs[c.BUDGET] + ), + ) + + xr_response_curves = ( + opt_result.get_response_curves() + if spec.include_response_curves + else None + ) + optimized_marketing_analysis = to_marketing_analysis( + spec=spec, + xr_data=opt_result.optimized_data, + xr_response_curves=xr_response_curves, + ) + nonoptimized_marketing_analysis = to_marketing_analysis( + spec=spec, + xr_data=opt_result.nonoptimized_data, + xr_response_curves=xr_response_curves, + ) + result = budget_pb.BudgetOptimizationResult( + name=spec.optimization_name, + spec=spec.to_proto(), + optimized_marketing_analysis=optimized_marketing_analysis, + nonoptimized_marketing_analysis=nonoptimized_marketing_analysis, + incremental_outcome_grid=_to_incremental_outcome_grid( + opt_result.optimization_grid.grid_dataset, + grid_name=spec.grid_name, + ), + ) + + if spec.group_id: + result.group_id = spec.group_id + return result + + +def to_marketing_analysis( + spec: model_processor.DatedSpec, + xr_data: xr.Dataset, + xr_response_curves: xr.Dataset | None, +) -> analysis_pb.MarketingAnalysis: + """Converts OptimizationResults to MarketingAnalysis protos. + + Args: + spec: The spec to build MarketingAnalysis protos for. + xr_data: The xr.Dataset to convert into MarketingAnalysis proto. + xr_response_curves: The xr.Dataset to convert into response curves. + + Returns: + A MarketingAnalysis proto. + """ + # `spec` should have been resolved with concrete date interval parameters. + assert spec.start_date is not None and spec.end_date is not None + marketing_analysis = analysis_pb.MarketingAnalysis( + date_interval=time_record.create_date_interval_pb( + start_date=spec.start_date, + end_date=spec.end_date, + tag=spec.date_interval_tag, + ), + ) + # Include the response curves data for all channels at the optimized freq. + channel_response_curve_protos = _to_channel_response_curve_protos( + xr_response_curves + ) + + # Create a per-channel MediaAnalysis. + for channel in xr_data.channel.values: + channel_data = xr_data.sel(channel=channel) + spend = channel_data.spend.item() + # TODO: Resolve conflict definition of spend share. 
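+    # `pct_of_spend` is the channel's share of total spend over the selected
+    # date interval, as reported by the optimizer's dataset.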
+    spend_share = channel_data.pct_of_spend.item()
+    channel_media_analysis = media_analysis_pb.MediaAnalysis(
+        channel_name=channel,
+        spend_info=media_analysis_pb.SpendInfo(
+            spend=spend,
+            spend_share=spend_share,
+        ),
+    )
+    # Output one outcome per channel: either revenue or non-revenue,
+    # but not both.
+    channel_media_analysis.media_outcomes.append(_to_outcome(channel_data))
+    if xr_response_curves is not None:
+      channel_media_analysis.response_curve.CopyFrom(
+          channel_response_curve_protos[channel]
+      )
+    marketing_analysis.media_analyses.append(channel_media_analysis)
+
+  return marketing_analysis
+
+
+def _get_channel_constraints_abs(
+    opt_result: optimizer.OptimizationResults,
+    constraint_lower: Sequence[float],
+    constraint_upper: Sequence[float],
+) -> list[ChannelConstraintAbs]:
+  """Converts a sequence of channel constraints from relative to absolute terms.
+
+  Args:
+    opt_result: The optimization result.
+    constraint_lower: A sequence of lower bound constraints for each channel,
+      in relative terms.
+    constraint_upper: A sequence of upper bound constraints for each channel,
+      in relative terms.
+
+  Returns:
+    A list of channel constraints in absolute terms.
+  """
+  round_factor = opt_result.optimization_grid.round_factor
+  channels = opt_result.optimized_data.channel.values
+  (optimization_lower_bound, optimization_upper_bound) = (
+      optimizer.get_optimization_bounds(
+          n_channels=len(channels),
+          spend=opt_result.nonoptimized_data.spend.data,
+          round_factor=round_factor,
+          spend_constraint_lower=constraint_lower,
+          spend_constraint_upper=constraint_upper,
+      )
+  )
+
+  abs_constraints: list[ChannelConstraintAbs] = []
+  for i, channel in enumerate(channels):
+    constraint = ChannelConstraintAbs(
+        channel_name=channel,
+        abs_lowerbound=optimization_lower_bound[i],
+        abs_upperbound=optimization_upper_bound[i],
+    )
+    abs_constraints.append(constraint)
+  return abs_constraints
+
+
+def build_scenario_kwargs(
+    scenario: optimizer.FixedBudgetScenario | optimizer.FlexibleBudgetScenario,
+) -> dict[str, float]:
+  """Returns keyword arguments for an optimizer, given a spec's scenario.
+
+  The keys in the returned kwargs are a subset of the parameters of the
+  `optimizer.BudgetOptimizer.optimize()` method.
+
+  Args:
+    scenario: The scenario to build kwargs for.
+
+  Raises:
+    ValueError: If no scenario is specified in the spec, or if the values for
+      a given scenario type are invalid.
+  """
+  kwargs = {}
+  match scenario:
+    case optimizer.FixedBudgetScenario(total_budget):
+      if total_budget is not None:  # if not specified => historical spend
+        kwargs['budget'] = total_budget
+    case optimizer.FlexibleBudgetScenario(target_metric, target_value):
+      match target_metric:
+        case c.ROI:
+          key = 'target_roi'
+        case c.MROI:
+          key = 'target_mroi'
+        case _:
+          # Technically dead code, since this is already checked in
+          # `validate()`.
+          raise ValueError(
+              f'Unsupported target metric: {target_metric} for flexible'
+              ' budget scenario.'
+          )
+      kwargs[key] = target_value
+    case _:
+      # Technically dead code.
+      raise ValueError('Unsupported scenario type.')
+  return kwargs
+
+
+def build_constraints_kwargs(
+    constraints: Sequence[ChannelConstraint],
+    model_channels: Sequence[str],
+) -> dict[str, list[float]]:
+  """Returns `spend_constraint_**` kwargs for given channel constraints.
+
+  If a media channel is not present in the spec's channel constraints, the
+  default relative constraints are applied: a lower ratio of
+  `CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO` and an upper ratio of
+  `CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO`.
+
+  Args:
+    constraints: The channel constraints from the spec.
+ model_channels: The list of channels in the model. + + Raises: + ValueError: If the channel constraints are invalid (e.g. channel names are + not matched with the internal model data, etc). + """ + # Validate user-configured channel constraints in the spec. + constraints_by_channel_name = {c.channel_name: c for c in constraints} + constraint_channel_names = set(constraints_by_channel_name.keys()) + if not (constraint_channel_names <= set(model_channels)): + raise ValueError( + 'Channel constraints must have channel names that are in the model' + f' data. Expected {model_channels}, got {constraint_channel_names}.' + ) + + spend_constraint_lower = [] + spend_constraint_upper = [] + for channel in model_channels: + if channel in constraints_by_channel_name: + constraint = constraints_by_channel_name[channel] + if not isinstance(constraint, ChannelConstraintRel): + raise ValueError( + 'Channel constraints in user input must be expressed in relative' + ' ratio terms.' + ) + lowerbound = constraint.spend_constraint_lower + upperbound = constraint.spend_constraint_upper + else: + lowerbound = CHANNEL_CONSTRAINT_LOWERBOUND_DEFAULT_RATIO + upperbound = CHANNEL_CONSTRAINT_UPPERBOUND_DEFAULT_RATIO + + spend_constraint_lower.append(lowerbound) + spend_constraint_upper.append(upperbound) + + return { + 'spend_constraint_lower': spend_constraint_lower, + 'spend_constraint_upper': spend_constraint_upper, + } + + +def _to_channel_response_curve_protos( + optimized_response_curves: xr.Dataset | None, +) -> Mapping[str, response_curve_pb.ResponseCurve]: + """Converts a response curve dataframe to a map of channel to ResponseCurve. + + Args: + optimized_response_curves: A dataframe containing the response curve data. + This is the output of `OptimizationResults.get_response_curves()`. + + Returns: + A map of channel to ResponseCurve proto. + """ + if optimized_response_curves is None: + return {} + channels = optimized_response_curves.channel.values + # Flatten the dataset into a tabular dataframe so we can iterate over it. + df = ( + optimized_response_curves.to_dataframe() + .reset_index() + .pivot( + index=[c.CHANNEL, c.SPEND, c.SPEND_MULTIPLIER], + columns=c.METRIC, + values=c.INCREMENTAL_OUTCOME, + ) + .reset_index() + ).sort_values(by=[c.CHANNEL, c.SPEND]) + + channel_response_curves = { + channel: response_curve_pb.ResponseCurve(input_name=c.SPEND) + for channel in channels + } + + for _, row in df.iterrows(): + channel = row[c.CHANNEL] + response_point = response_curve_pb.ResponsePoint( + input_value=row[c.SPEND], + incremental_kpi=row[c.MEAN], + ) + channel_response_curves[channel].response_points.append(response_point) + + return channel_response_curves + + +def _to_outcome(channel_data: xr.Dataset) -> outcome_pb.Outcome: + """Returns an Outcome value for a given channel's media analysis. + + Args: + channel_data: A channel-selected dataset from `OptimizationResults`. 
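+
+  Returns:
+    An `Outcome` proto carrying ROI, marginal ROI, cost per contribution,
+    contribution, and effectiveness estimates at the dataset's confidence
+    level.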
+ """ + confidence_level = channel_data.attrs[c.CONFIDENCE_LEVEL] + is_revenue_kpi = channel_data.attrs[c.IS_REVENUE_KPI] + + return outcome_pb.Outcome( + kpi_type=( + kpi_type_pb.REVENUE if is_revenue_kpi else kpi_type_pb.NON_REVENUE + ), + roi=_to_estimate(channel_data.roi, confidence_level), + marginal_roi=_to_estimate(channel_data.mroi, confidence_level), + cost_per_contribution=_to_estimate( + channel_data.cpik, + confidence_level=confidence_level, + ), + contribution=outcome_pb.Contribution( + value=_to_estimate( + channel_data.incremental_outcome, confidence_level + ), + ), + effectiveness=outcome_pb.Effectiveness( + media_unit=c.IMPRESSIONS, + value=_to_estimate(channel_data.effectiveness, confidence_level), + ), + ) + + +def _to_incremental_outcome_grid( + optimization_grid: xr.Dataset, + grid_name: str | None, +) -> budget_pb.IncrementalOutcomeGrid: + """Converts an optimization grid to an `IncrementalOutcomeGrid` proto. + + Args: + optimization_grid: The optimization grid dataset in + `OptimizationResults.optimization_grid`. + grid_name: A user-given name for this grid. + + Returns: + An `IncrementalOutcomeGrid` proto. + """ + grid = budget_pb.IncrementalOutcomeGrid( + name=(grid_name or ''), + spend_step_size=optimization_grid.spend_step_size, + ) + for channel in optimization_grid.channel.values: + channel_grid = optimization_grid.sel(channel=channel) + spend_grid = channel_grid.spend_grid.dropna(dim=c.GRID_SPEND_INDEX) + incremental_outcome_grid = channel_grid.incremental_outcome_grid.dropna( + dim=c.GRID_SPEND_INDEX + ) + if len(spend_grid) != len(incremental_outcome_grid): + raise ValueError( + f'Spend grid and incremental outcome grid for channel "{channel}" do' + ' not agree.' + ) + channel_cells = budget_pb.IncrementalOutcomeGrid.ChannelCells( + channel_name=channel, + cells=[ + budget_pb.IncrementalOutcomeGrid.Cell( + spend=spend.item(), + incremental_outcome=estimate_pb.Estimate( + value=incr_outcome.item() + ), + ) + for (spend, incr_outcome) in zip( + spend_grid, incremental_outcome_grid + ) + ], + ) + grid.channel_cells.append(channel_cells) + return grid + + +def _to_estimate( + dataarray: xr.DataArray, + confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL, +) -> estimate_pb.Estimate: + """Converts a DataArray with (mean, ci_lo, ci_hi) `metric` datavars.""" + estimate = estimate_pb.Estimate(value=dataarray.sel(metric=c.MEAN).item()) + uncertainty = estimate_pb.Estimate.Uncertainty( + probability=confidence_level, + lowerbound=dataarray.sel(metric=c.CI_LO).item(), + upperbound=dataarray.sel(metric=c.CI_HI).item(), + ) + estimate.uncertainties.append(uncertainty) + return estimate + + +def _target_metric_to_proto( + target_metric: str, +) -> target_pb.TargetMetric: + """Converts a TargetMetric enum to a TargetMetric proto.""" + match target_metric: + case c.ROI: + return target_pb.TargetMetric.ROI + case c.MROI: + return target_pb.TargetMetric.MARGINAL_ROI + case _: + raise ValueError(f'Unsupported target metric: {target_metric}') diff --git a/schema/processors/common.py b/schema/processors/common.py new file mode 100644 index 000000000..d738b980e --- /dev/null +++ b/schema/processors/common.py @@ -0,0 +1,64 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes and functions common to modules in this directory.""" + +import enum + +from meridian import constants +from mmm.v1.common import estimate_pb2 as estimate_pb +from mmm.v1.common import kpi_type_pb2 as kpi_type_pb +from mmm.v1.common import target_metric_pb2 as target_pb +import xarray as xr + + +__all__ = [ + "TargetMetric", + "KpiType", + "to_estimate", +] + + +class TargetMetric(enum.Enum): + KPI = target_pb.TargetMetric.KPI + ROI = target_pb.TargetMetric.ROI + MARGINAL_ROI = target_pb.TargetMetric.MARGINAL_ROI + + +@enum.unique +class KpiType(enum.Enum): + """Enum for KPI type used in analysis and optimization.""" + + REVENUE = kpi_type_pb.KpiType.REVENUE + NON_REVENUE = kpi_type_pb.KpiType.NON_REVENUE + + +def to_estimate( + dataarray: xr.DataArray, + confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL, + metric: str = constants.MEAN, +) -> estimate_pb.Estimate: + """Converts a DataArray with (mean [or median for CPIK], ci_lo, ci_hi) `metric` data vars.""" + value = dataarray.sel(metric=metric).item() + estimate = estimate_pb.Estimate( + value=value + ) + uncertainty = estimate_pb.Estimate.Uncertainty( + probability=confidence_level, + lowerbound=dataarray.sel(metric=constants.CI_LO).item(), + upperbound=dataarray.sel(metric=constants.CI_HI).item(), + ) + estimate.uncertainties.append(uncertainty) + + return estimate diff --git a/schema/processors/common_test.py b/schema/processors/common_test.py new file mode 100644 index 000000000..0b43b4cb7 --- /dev/null +++ b/schema/processors/common_test.py @@ -0,0 +1,57 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from meridian import constants +from mmm.v1.common import estimate_pb2 as estimate_pb +from schema.processors import common +import numpy as np +import xarray as xr + +from tensorflow.python.util.protobuf import compare + + +class CommonTest(absltest.TestCase): + + def test_to_estimate_returns_correct_estimate_proto(self): + data = xr.DataArray( + data=np.array([100.0, 90.0, 110.0]), + dims=[constants.METRIC], + coords={ + constants.METRIC: [ + constants.MEAN, + constants.CI_LO, + constants.CI_HI, + ] + }, + ) + + estimate_proto = common.to_estimate( + dataarray=data, confidence_level=constants.DEFAULT_CONFIDENCE_LEVEL + ) + expected_estimate_proto = estimate_pb.Estimate( + value=100.0, + uncertainties=[ + estimate_pb.Estimate.Uncertainty( + probability=constants.DEFAULT_CONFIDENCE_LEVEL, + lowerbound=90.0, + upperbound=110.0, + ) + ], + ) + compare.assertProtoEqual(self, estimate_proto, expected_estimate_proto) + + +if __name__ == "__main__": + absltest.main() diff --git a/schema/processors/marketing_processor.py b/schema/processors/marketing_processor.py new file mode 100644 index 000000000..d167653e5 --- /dev/null +++ b/schema/processors/marketing_processor.py @@ -0,0 +1,1136 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Meridian module for analyzing marketing data in a Meridian model. + +This module provides a `MarketingProcessor`, designed to extract key marketing +insights from a trained Meridian model. It allows users to understand the impact +of different marketing channels, calculate return on investment (ROI), and +generate response curves. + +The processor uses specifications defined in `MarketingAnalysisSpec` to control +the analysis. Users can request: + +1. **Media Summary Metrics:** Aggregated performance metrics for each media + channel, including spend, contribution, ROI, and effectiveness. +2. **Incremental Outcomes:** The additional KPI or revenue driven by marketing + activities, calculated by comparing against a baseline scenario (e.g., zero + spend). +3. **Response Curves:** Visualizations of how the predicted KPI or revenue + changes as spend on a particular channel increases, helping to identify + diminishing returns. + +The results are output as a `MarketingAnalysisList` protobuf message, containing +detailed breakdowns per channel and for the baseline. + +Key Classes: + +- `MediaSummarySpec`: Configures the calculation of summary metrics like ROI. +- `IncrementalOutcomeSpec`: Configures the calculation of incremental impact. +- `ResponseCurveSpec`: Configures response curve generation. +- `MarketingAnalysisSpec`: The main specification to combine the above, + define date ranges, and set confidence levels. +- `MarketingProcessor`: The processor class that executes the analysis based + on the provided specs. + +Example Usage: + +1. 
**Get Media Summary Metrics for a specific period:** + + ```python + from schema.processors import marketing_processor + import datetime + + # Assuming 'trained_model' is a loaded Meridian model object + + spec = marketing_processor.MarketingAnalysisSpec( + analysis_name="q1_summary", + start_date=datetime.date(2023, 1, 1), + end_date=datetime.date(2023, 3, 31), + media_summary_spec=marketing_processor.MediaSummarySpec( + aggregate_times=True + ), + response_curve_spec=marketing_processor.ResponseCurveSpec(), + confidence_level=0.9, + ) + + processor = marketing_processor.MarketingProcessor(trained_model) + # `result` is a `marketing_analysis_pb2.MarketingAnalysisList` proto + result = processor.execute([spec]) + ``` + +2. **Calculate Incremental Outcome with new spend data:** + + ```python + from schema.processors import marketing_processor + from meridian.analysis import analyzer + import datetime + import numpy as np + + # Assuming 'trained_model' is a loaded Meridian model object + # Assuming 'new_media_spend' is a numpy array with shape (time, channels) + + # Create DataTensors for the new data + # Example: + # new_data = analyzer.DataTensors( + # media=new_media_spend, + # time=new_time_index, + # ) + + spec = marketing_processor.MarketingAnalysisSpec( + analysis_name="what_if_scenario", + # NOTE: Dates must align with `new_data.time` + start_date=datetime.date(2023, 1, 1), + end_date=datetime.date(2023, 1, 31), + incremental_outcome_spec=marketing_processor.IncrementalOutcomeSpec( + new_data=new_data, + aggregate_times=True, + ), + ) + + processor = marketing_processor.MarketingProcessor(trained_model) + result = processor.execute([spec]) + + print(f"Incremental Outcome for {spec.analysis_name}:") + # Process results from result.marketing_analyses + ``` + +Note: You can provide the processor with multiple specs. This would result in +multiple marketing analysis results in the output. +""" + +from collections.abc import Sequence +import dataclasses +import datetime +import functools +import warnings + +from meridian import constants +from meridian.analysis import analyzer +from meridian.data import time_coordinates +from mmm.v1 import mmm_pb2 +from mmm.v1.common import date_interval_pb2 +from mmm.v1.common import kpi_type_pb2 +from mmm.v1.marketing.analysis import marketing_analysis_pb2 +from mmm.v1.marketing.analysis import media_analysis_pb2 +from mmm.v1.marketing.analysis import non_media_analysis_pb2 +from mmm.v1.marketing.analysis import outcome_pb2 +from mmm.v1.marketing.analysis import response_curve_pb2 +from schema.processors import common +from schema.processors import model_processor +import numpy as np +import xarray as xr + +__all__ = [ + "MediaSummarySpec", + "IncrementalOutcomeSpec", + "ResponseCurveSpec", + "MarketingAnalysisSpec", + "MarketingProcessor", +] + + +@dataclasses.dataclass(frozen=True) +class MediaSummarySpec(model_processor.Spec): + """Stores parameters needed for creating media summary metrics. + + Attributes: + aggregate_times: Boolean. If `True`, the media summary metrics are + aggregated over time. Defaults to `True`. + marginal_roi_by_reach: Boolean. Marginal ROI (mROI) is defined as the return + on the next dollar spent. If this argument is `True`, the assumption is + that the next dollar spent only impacts reach, holding frequency constant. + If this argument is `False`, the assumption is that the next dollar spent + only impacts frequency, holding reach constant. Defaults to `True`. + include_non_paid_channels: Boolean. 
If `True`, the media summary metrics
+      include non-paid channels. Defaults to `False`.
+    new_data: Optional `DataTensors` container with optional tensors: `media`,
+      `reach`, `frequency`, `organic_media`, `organic_reach`,
+      `organic_frequency`, `non_media_treatments` and `revenue_per_kpi`. If
+      `None`, the metrics are calculated using the `InputData` provided to the
+      Meridian object. If `new_data` is provided, the metrics are calculated
+      using the new tensors in `new_data` and the original values of the
+      remaining tensors.
+    media_selected_times: Optional list containing booleans with length equal
+      to the number of time periods in `new_data`, if provided. If `new_data`
+      is provided, `media_selected_times` can select any subset of time
+      periods in `new_data`. If `new_data` is not provided,
+      `media_selected_times` selects from the model's original media data.
+  """
+
+  aggregate_times: bool = True
+  marginal_roi_by_reach: bool = True
+  include_non_paid_channels: bool = False
+  # b/384034128 Use new args in `summary_metrics`.
+  new_data: analyzer.DataTensors | None = None
+  media_selected_times: Sequence[bool] | None = None
+
+  def validate(self):
+    pass
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class IncrementalOutcomeSpec(model_processor.Spec):
+  """Stores parameters needed for computing incremental outcomes.
+
+  Attributes:
+    aggregate_times: Boolean. If `True`, the media summary metrics are
+      aggregated over time. Defaults to `True`.
+    new_data: Optional `DataTensors` container with optional tensors: `media`,
+      `reach`, `frequency`, `organic_media`, `organic_reach`,
+      `organic_frequency`, `non_media_treatments` and `revenue_per_kpi`. If
+      `None`, the incremental outcome is calculated using the `InputData`
+      provided to the Meridian object. If `new_data` is provided, the
+      incremental outcome is calculated using the new tensors in `new_data` and
+      the original values of the remaining tensors. For example,
+      `incremental_outcome(new_data=DataTensors(media=new_media))` computes the
+      incremental outcome using `new_media` and the original values of `reach`,
+      `frequency`, `organic_media`, `organic_reach`, `organic_frequency`,
+      `non_media_treatments` and `revenue_per_kpi`. If any of the tensors in
+      `new_data` is provided with a different number of time periods than in
+      `InputData`, then all tensors must be provided with the same number of
+      time periods.
+    media_selected_times: Optional list containing booleans with length equal
+      to the number of time periods in `new_data`, if provided. If `new_data`
+      is provided, `media_selected_times` can select any subset of time
+      periods in `new_data`. If `new_data` is not provided,
+      `media_selected_times` selects from the model's original media data and
+      its length must be equal to the number of time periods in the model's
+      original media data.
+    include_non_paid_channels: Boolean. If `True`, the incremental outcome
+      includes non-paid channels. Defaults to `False`.
+  """
+
+  aggregate_times: bool = True
+  new_data: analyzer.DataTensors | None = None
+  media_selected_times: Sequence[bool] | None = None
+  include_non_paid_channels: bool = False
+
+  def validate(self):
+    super().validate()
+    if (self.new_data is not None) and (self.new_data.time is None):
+      raise ValueError("`time` must be provided in `new_data`.")
+
+
+@dataclasses.dataclass(frozen=True)
+class ResponseCurveSpec(model_processor.Spec):
+  """Stores parameters needed for creating response curves.
+
+  Attributes:
+    by_reach: Boolean. Only applies to channels with reach and frequency data.
+      If `True`, the response curve is computed as a function of reach; if
+      `False`, it is computed as a function of frequency.
+  """
+
+  by_reach: bool = True
+
+  def validate(self):
+    pass
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class MarketingAnalysisSpec(model_processor.DatedSpec):
+  """Stores parameters needed for processing a model into `MarketingAnalysis`s.
+
+  Either `media_summary_spec` or `incremental_outcome_spec` must be provided,
+  but not both.
+
+  Attributes:
+    media_summary_spec: Parameters for creating media summary metrics. Mutually
+      exclusive with `incremental_outcome_spec`.
+    incremental_outcome_spec: Parameters for creating incremental outcome.
+      Mutually exclusive with `media_summary_spec`. If `new_data` is provided,
+      then the start and end dates of this `MarketingAnalysisSpec` must be
+      within the `new_data.time`.
+    response_curve_spec: Parameters for creating response curves. Response
+      curves are only computed for specs that aggregate times and have a
+      `media_summary_spec` selected.
+    confidence_level: Confidence level for credible intervals, represented as a
+      value between zero and one. Defaults to 0.9.
+  """
+
+  media_summary_spec: MediaSummarySpec | None = None
+  incremental_outcome_spec: IncrementalOutcomeSpec | None = None
+  response_curve_spec: ResponseCurveSpec = dataclasses.field(
+      default_factory=ResponseCurveSpec
+  )
+  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL
+
+  def validate(self):
+    super().validate()
+    if self.confidence_level <= 0 or self.confidence_level >= 1:
+      raise ValueError(
+          "Confidence level must be greater than 0 and less than 1."
+      )
+    if (
+        self.media_summary_spec is None
+        and self.incremental_outcome_spec is None
+    ):
+      raise ValueError(
+          "At least one of `media_summary_spec` or `incremental_outcome_spec`"
+          " must be provided."
+      )
+    if (
+        self.media_summary_spec is not None
+        and self.incremental_outcome_spec is not None
+    ):
+      raise ValueError(
+          "Only one of `media_summary_spec` or `incremental_outcome_spec` can"
+          " be provided."
+      )
+
+
+class MarketingProcessor(
+    model_processor.ModelProcessor[
+        MarketingAnalysisSpec, marketing_analysis_pb2.MarketingAnalysisList
+    ]
+):
+  """Generates `MarketingAnalysis` protos for a given trained Meridian model.
+
+  A `MarketingAnalysis` proto is generated for each spec supplied to
+  `execute()`. Within each `MarketingAnalysis` proto, a `MediaAnalysis` proto
+  is created for each channel in the model. One `NonMediaAnalysis` proto is
+  also created for the model's baseline data.
+  """
+
+  def __init__(
+      self,
+      trained_model: model_processor.ModelType,
+  ):
+    trained_model = model_processor.ensure_trained_model(trained_model)
+    self._analyzer = trained_model.internal_analyzer
+    self._meridian = trained_model.mmm
+    self._model_time_coordinates = trained_model.time_coordinates
+    self._interval_length = self._model_time_coordinates.interval_days
+
+    # If the input data KPI type is "revenue", then the `revenue_per_kpi`
+    # tensor must exist, and general-KPI type outcomes should not be defined.
+    self._revenue_kpi_type = (
+        trained_model.mmm.input_data.kpi_type == constants.REVENUE
+    )
+    # `_kpi_only` is `True` iff the input data KPI type is "non-revenue" AND
+    # the `revenue_per_kpi` tensor is None.
+ self._kpi_only = trained_model.mmm.input_data.revenue_per_kpi is None + + @classmethod + def spec_type(cls) -> type[MarketingAnalysisSpec]: + return MarketingAnalysisSpec + + @classmethod + def output_type(cls) -> type[marketing_analysis_pb2.MarketingAnalysisList]: + return marketing_analysis_pb2.MarketingAnalysisList + + def _set_output( + self, + output: mmm_pb2.Mmm, + result: marketing_analysis_pb2.MarketingAnalysisList, + ): + output.marketing_analysis_list.CopyFrom(result) + + def execute( + self, marketing_analysis_specs: Sequence[MarketingAnalysisSpec] + ) -> marketing_analysis_pb2.MarketingAnalysisList: + """Runs a marketing analysis on the model based on the given specs. + + A `MarketingAnalysis` proto is created for each of the given specs. Each + `MarketingAnalysis` proto contains a list of `MediaAnalysis` protos and a + singleton `NonMediaAnalysis` proto for the baseline analysis. The analysis + covers the time period bounded by the spec's start and end dates. + + The singleton non-media analysis is performed on the model's baseline data, + and contains metrics such as incremental outcome and baseline percent of + contribution across media and non-media. + + A media analysis is performed for each channel in the model, plus an + "All Channels" synthetic channel. The media analysis contains metrics such + as spend, percent of spend, incremental outcome, percent of contribution, + and effectiveness. Depending on the type of data (revenue-based or + non-revenue-based) in the model, the analysis also contains CPIK + (non-revenue-based) or ROI and MROI (revenue-based). + + Args: + marketing_analysis_specs: A sequence of MarketingAnalysisSpec objects. + + Returns: + A MarketingAnalysisList proto containing the results of the marketing + analysis for each spec. + """ + marketing_analysis_list: list[marketing_analysis_pb2.MarketingAnalysis] = [] + + for spec in marketing_analysis_specs: + if ( + spec.incremental_outcome_spec is not None + and spec.incremental_outcome_spec.new_data is not None + and spec.incremental_outcome_spec.new_data.time is not None + ): + new_time_coords = time_coordinates.TimeCoordinates.from_dates( + np.asarray(spec.incremental_outcome_spec.new_data.time) + .astype(str) + .tolist() + ) + resolver = spec.resolver(new_time_coords) + else: + resolver = spec.resolver(self._model_time_coordinates) + media_summary_marketing_analyses = ( + self._generate_marketing_analyses_for_media_summary_spec( + spec, resolver + ) + ) + incremental_outcome_marketing_analyses = ( + self._generate_marketing_analyses_for_incremental_outcome_spec( + spec, resolver + ) + ) + marketing_analysis_list.extend( + media_summary_marketing_analyses + + incremental_outcome_marketing_analyses + ) + + return marketing_analysis_pb2.MarketingAnalysisList( + marketing_analyses=marketing_analysis_list + ) + + def _generate_marketing_analyses_for_media_summary_spec( + self, + marketing_analysis_spec: MarketingAnalysisSpec, + resolver: model_processor.DatedSpecResolver, + ) -> list[marketing_analysis_pb2.MarketingAnalysis]: + """Creates a list of MarketingAnalysis protos based on the given spec. + + If spec's `aggregate_times` is True, then only one MarketingAnalysis proto + is created. Otherwise, one MarketingAnalysis proto is created for each date + interval in the spec. + + Args: + marketing_analysis_spec: An instance of MarketingAnalysisSpec. + resolver: A DatedSpecResolver instance. 
+
+    Returns:
+      A list of `MarketingAnalysis` protos containing the results of the
+      marketing analysis for the given spec.
+    """
+    media_summary_spec = marketing_analysis_spec.media_summary_spec
+    if media_summary_spec is None:
+      return []
+
+    selected_times = resolver.resolve_to_enumerated_selected_times()
+    # This contains either a revenue-based KPI or a non-revenue KPI analysis.
+    media_summary_metrics, non_media_summary_metrics = (
+        self._generate_media_and_non_media_summary_metrics(
+            media_summary_spec,
+            selected_times,
+            marketing_analysis_spec.confidence_level,
+            self._kpi_only,
+        )
+    )
+
+    secondary_non_revenue_kpi_metrics = None
+    secondary_non_revenue_kpi_non_media_metrics = None
+    # If the input data KPI type is "non-revenue", and we calculated its
+    # revenue-based KPI outcomes above, then we should also compute its
+    # non-revenue KPI outcomes.
+    if not self._revenue_kpi_type and not self._kpi_only:
+      (
+          secondary_non_revenue_kpi_metrics,
+          secondary_non_revenue_kpi_non_media_metrics,
+      ) = self._generate_media_and_non_media_summary_metrics(
+          media_summary_spec,
+          selected_times,
+          marketing_analysis_spec.confidence_level,
+          use_kpi=True,
+      )
+
+    # Note: `baseline_summary_metrics()` computes a revenue-based baseline
+    # outcome here, scaling the generic KPI by `revenue_per_kpi` whenever it
+    # is defined.
+    # TODO: Baseline outcomes for both revenue and non-revenue
+    # KPI types should be computed when possible.
+    baseline_outcome = self._analyzer.baseline_summary_metrics(
+        confidence_level=marketing_analysis_spec.confidence_level,
+        aggregate_times=media_summary_spec.aggregate_times,
+        selected_times=selected_times,
+    ).sel(distribution=constants.POSTERIOR)
+
+    # Response curves are only computed for specs that aggregate times.
+    if media_summary_spec.aggregate_times:
+      response_curve_spec = marketing_analysis_spec.response_curve_spec
+      response_curves = self._analyzer.response_curves(
+          confidence_level=marketing_analysis_spec.confidence_level,
+          use_posterior=True,
+          selected_times=selected_times,
+          use_kpi=self._kpi_only,
+          by_reach=response_curve_spec.by_reach,
+      )
+    else:
+      response_curves = None
+      warnings.warn(
+          "Response curves are not computed for non-aggregated time periods."
+      )
+
+    date_intervals = self._build_time_intervals(
+        aggregate_times=media_summary_spec.aggregate_times,
+        resolver=resolver,
+    )
+
+    return self._marketing_metrics_to_protos(
+        media_summary_metrics,
+        non_media_summary_metrics,
+        baseline_outcome,
+        secondary_non_revenue_kpi_metrics,
+        secondary_non_revenue_kpi_non_media_metrics,
+        response_curves,
+        marketing_analysis_spec,
+        date_intervals,
+    )
+
+  def _generate_media_and_non_media_summary_metrics(
+      self,
+      media_summary_spec: MediaSummarySpec,
+      selected_times: list[str] | None,
+      confidence_level: float,
+      use_kpi: bool,
+  ) -> tuple[xr.Dataset | None, xr.Dataset | None]:
+    if media_summary_spec is None:
+      return (None, None)
+    compute_media_summary_metrics = functools.partial(
+        self._analyzer.summary_metrics,
+        marginal_roi_by_reach=media_summary_spec.marginal_roi_by_reach,
+        selected_times=selected_times,
+        aggregate_geos=True,
+        aggregate_times=media_summary_spec.aggregate_times,
+        confidence_level=confidence_level,
+    )
+
+    media_summary_metrics = compute_media_summary_metrics(
+        use_kpi=use_kpi,
+        include_non_paid_channels=False,
+    ).sel(distribution=constants.POSTERIOR)
+    # TODO: Produce one metrics dataset for both paid and non-paid channels.
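+    # When non-paid channels are requested, `summary_metrics` is called a
+    # second time with `include_non_paid_channels=True`; the paid channels
+    # already captured above are then dropped from that second result (and the
+    # synthetic "All Channels" entry from the first) so that each channel is
+    # reported exactly once across the two datasets.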
+ non_media_summary_metrics = None + if media_summary_spec.include_non_paid_channels: + media_summary_metrics = media_summary_metrics.drop_sel( + channel=constants.ALL_CHANNELS + ) + non_media_summary_metrics = ( + compute_media_summary_metrics( + use_kpi=use_kpi, + include_non_paid_channels=True, + ) + .sel(distribution=constants.POSTERIOR) + .drop_sel( + channel=media_summary_metrics.coords[constants.CHANNEL].data + ) + ) + return media_summary_metrics, non_media_summary_metrics + + def _generate_marketing_analyses_for_incremental_outcome_spec( + self, + marketing_analysis_spec: MarketingAnalysisSpec, + resolver: model_processor.DatedSpecResolver, + ) -> list[marketing_analysis_pb2.MarketingAnalysis]: + """Creates a list of `MarketingAnalysis` protos based on the given spec. + + If the spec's `aggregate_times` is True, then only one `MarketingAnalysis` + proto is created. Otherwise, one `MarketingAnalysis` proto is created for + each date interval in the spec. + + Args: + marketing_analysis_spec: An instance of MarketingAnalysisSpec. + resolver: A DatedSpecResolver instance. + + Returns: + A list of `MarketingAnalysis` protos containing the results of the + marketing analysis for the given spec. + """ + incremental_outcome_spec = marketing_analysis_spec.incremental_outcome_spec + if incremental_outcome_spec is None: + return [] + + compute_incremental_outcome = functools.partial( + self._incremental_outcome_dataset, + resolver=resolver, + new_data=incremental_outcome_spec.new_data, + media_selected_times=incremental_outcome_spec.media_selected_times, + aggregate_geos=True, + aggregate_times=incremental_outcome_spec.aggregate_times, + confidence_level=marketing_analysis_spec.confidence_level, + include_non_paid_channels=False, + ) + # This contains either a revenue-based KPI or a non-revenue KPI analysis. + incremental_outcome = compute_incremental_outcome(use_kpi=self._kpi_only) + + secondary_non_revenue_kpi_metrics = None + # If the input data KPI type is "non-revenue", and we calculated its + # revenue-based KPI outcomes above, then we should also compute its + # non-revenue KPI outcomes. + if not self._revenue_kpi_type and not self._kpi_only: + secondary_non_revenue_kpi_metrics = compute_incremental_outcome( + use_kpi=True + ) + + date_intervals = self._build_time_intervals( + aggregate_times=incremental_outcome_spec.aggregate_times, + resolver=resolver, + ) + + return self._marketing_metrics_to_protos( + metrics=incremental_outcome, + non_media_metrics=None, + baseline_outcome=None, + secondary_non_revenue_kpi_metrics=secondary_non_revenue_kpi_metrics, + secondary_non_revenue_kpi_non_media_metrics=None, + response_curves=None, + marketing_analysis_spec=marketing_analysis_spec, + date_intervals=date_intervals, + ) + + def _build_time_intervals( + self, + aggregate_times: bool, + resolver: model_processor.DatedSpecResolver, + ) -> list[date_interval_pb2.DateInterval]: + """Creates a list of `DateInterval` protos for the given spec. + + Args: + aggregate_times: Whether to aggregate times. + resolver: A DatedSpecResolver instance. + + Returns: + A list of `DateInterval` protos for the given spec. + """ + if aggregate_times: + date_interval = resolver.collapse_to_date_interval_proto() + # This means metrics are aggregated over time, only one date interval is + # needed. + return [date_interval] + + # This list will contain all date intervals for the given spec. All dates + # in this list will share a common tag. 
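+    # One `DateInterval` proto is emitted per time period covered by the
+    # spec's date range.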
+ return resolver.transform_to_date_interval_protos() + + def _marketing_metrics_to_protos( + self, + metrics: xr.Dataset, + non_media_metrics: xr.Dataset | None, + baseline_outcome: xr.Dataset | None, + secondary_non_revenue_kpi_metrics: xr.Dataset | None, + secondary_non_revenue_kpi_non_media_metrics: xr.Dataset | None, + response_curves: xr.Dataset | None, + marketing_analysis_spec: MarketingAnalysisSpec, + date_intervals: Sequence[date_interval_pb2.DateInterval], + ) -> list[marketing_analysis_pb2.MarketingAnalysis]: + """Creates a list of MarketingAnalysis protos from datasets.""" + if metrics is None: + raise ValueError("metrics is None") + + media_channels = list(metrics.coords[constants.CHANNEL].data) + non_media_channels = ( + list(non_media_metrics.coords[constants.CHANNEL].data) + if non_media_metrics + else [] + ) + channels = media_channels + non_media_channels + channels_with_response_curve = ( + response_curves.coords[constants.CHANNEL].data + if response_curves + else [] + ) + marketing_analyses = [] + for date_interval in date_intervals: + start_date = date_interval.start_date + start_date_str = datetime.date( + start_date.year, start_date.month, start_date.day + ).strftime(constants.DATE_FORMAT) + media_analyses: list[media_analysis_pb2.MediaAnalysis] = [] + non_media_analyses: list[non_media_analysis_pb2.NonMediaAnalysis] = [] + + # For all channels reported in the media summary metrics + for channel_name in channels: + channel_response_curve = None + if response_curves and (channel_name in channels_with_response_curve): + channel_response_curve = response_curves.sel( + {constants.CHANNEL: channel_name} + ) + is_media_channel = channel_name in media_channels + + channel_analysis = self._get_channel_metrics( + marketing_analysis_spec, + channel_name, + start_date_str, + metrics if is_media_channel else non_media_metrics, + secondary_non_revenue_kpi_metrics + if is_media_channel + else secondary_non_revenue_kpi_non_media_metrics, + channel_response_curve, + is_media_channel, + ) + if isinstance(channel_analysis, media_analysis_pb2.MediaAnalysis): + media_analyses.append(channel_analysis) + + if isinstance( + channel_analysis, non_media_analysis_pb2.NonMediaAnalysis + ): + non_media_analyses.append(channel_analysis) + + marketing_analysis = marketing_analysis_pb2.MarketingAnalysis( + date_interval=date_interval, + media_analyses=media_analyses, + non_media_analyses=non_media_analyses, + ) + if baseline_outcome is not None: + baseline_analysis = self._get_baseline_metrics( + marketing_analysis_spec=marketing_analysis_spec, + baseline_outcome=baseline_outcome, + start_date=start_date_str, + ) + marketing_analysis.non_media_analyses.append(baseline_analysis) + + marketing_analyses.append(marketing_analysis) + + return marketing_analyses + + def _get_channel_metrics( + self, + marketing_analysis_spec: MarketingAnalysisSpec, + channel_name: str, + start_date_str: str, + metrics: xr.Dataset, + secondary_metrics: xr.Dataset | None, + channel_response_curves: xr.Dataset | None, + is_media_channel: bool, + ) -> ( + media_analysis_pb2.MediaAnalysis | non_media_analysis_pb2.NonMediaAnalysis + ): + """Returns a MediaAnalysis proto for the given channel.""" + if constants.TIME in metrics.coords: + sel = { + constants.CHANNEL: channel_name, + constants.TIME: start_date_str, + } + else: + sel = {constants.CHANNEL: channel_name} + + channel_metrics = metrics.sel(sel) + if secondary_metrics is not None: + channel_secondary_metrics = secondary_metrics.sel(sel) + else: + 
channel_secondary_metrics = None
+
+    return self._channel_metrics_to_proto(
+        channel_metrics,
+        channel_secondary_metrics,
+        channel_response_curves,
+        channel_name,
+        is_media_channel,
+        marketing_analysis_spec.confidence_level,
+    )
+
+  def _channel_metrics_to_proto(
+      self,
+      channel_media_summary_metrics: xr.Dataset,
+      channel_secondary_non_revenue_metrics: xr.Dataset | None,
+      channel_response_curve: xr.Dataset | None,
+      channel_name: str,
+      is_media_channel: bool,
+      confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
+  ) -> (
+      media_analysis_pb2.MediaAnalysis | non_media_analysis_pb2.NonMediaAnalysis
+  ):
+    """Creates a `MediaAnalysis` or `NonMediaAnalysis` proto for the channel.
+
+    Args:
+      channel_media_summary_metrics: A dataset containing the model's media
+        summary metrics. This dataset is pre-filtered to `channel_name`. This
+        dataset contains revenue-based metrics if the model's input data is
+        revenue-based, or if `revenue_per_kpi` is defined. Otherwise, it
+        contains non-revenue generic KPI metrics.
+      channel_secondary_non_revenue_metrics: A dataset containing the model's
+        non-revenue-based media summary metrics. This is defined iff the
+        input data is non-revenue type AND `revenue_per_kpi` is available. In
+        this case, `channel_media_summary_metrics` contains revenue-based
+        metrics computed from `KPI * revenue_per_kpi`, and this dataset
+        contains media summary metrics based on the model's generic KPI alone.
+        In all other cases, this is `None`.
+      channel_response_curve: A dataset containing the data needed to generate
+        a response curve. This dataset is pre-filtered to `channel_name`.
+      channel_name: The name of the channel to analyze.
+      is_media_channel: Whether the channel is a media channel.
+      confidence_level: Confidence level for credible intervals, represented as
+        a value between zero and one.
+
+    Returns:
+      A proto containing the analysis results for the given channel.
+    """
+
+    spend_info = _compute_spend(channel_media_summary_metrics)
+    is_all_channels = channel_name == constants.ALL_CHANNELS
+
+    compute_outcome = functools.partial(
+        self._compute_outcome,
+        is_all_channels=is_all_channels,
+        confidence_level=confidence_level,
+    )
+
+    outcomes = [
+        compute_outcome(
+            channel_media_summary_metrics,
+            is_revenue_type=(not self._kpi_only),
+        )
+    ]
+    # If `channel_media_summary_metrics` holds non-revenue data reported as a
+    # revenue-type outcome, then media summary metrics for the generic KPI
+    # counterpart should have been provided as well.
+    if channel_secondary_non_revenue_metrics is not None:
+      outcomes.append(
+          compute_outcome(
+              channel_secondary_non_revenue_metrics,
+              is_revenue_type=False,
+          )
+      )
+
+    if not is_media_channel:
+      return non_media_analysis_pb2.NonMediaAnalysis(
+          non_media_name=channel_name,
+          non_media_outcomes=outcomes,
+      )
+
+    media_analysis = media_analysis_pb2.MediaAnalysis(
+        channel_name=channel_name,
+        media_outcomes=outcomes,
+    )
+
+    if spend_info is not None:
+      media_analysis.spend_info.CopyFrom(spend_info)
+
+    if channel_response_curve is not None:
+      media_analysis.response_curve.CopyFrom(
+          self._compute_response_curve(
+              channel_response_curve,
+          )
+      )
+
+    return media_analysis
+
+  def _get_baseline_metrics(
+      self,
+      marketing_analysis_spec: MarketingAnalysisSpec,
+      baseline_outcome: xr.Dataset,
+      start_date: str,
+  ) -> non_media_analysis_pb2.NonMediaAnalysis:
+    """Analyzes "baseline" pseudo-channel outcomes over the given time points.
+
+    Args:
+      marketing_analysis_spec: The user-provided parameter spec for this
+        analysis.
+      baseline_outcome: A dataset containing the model's baseline summary
+        metrics.
+      start_date: The start date of the analyzed interval.
+
+    Returns:
+      A `NonMediaAnalysis` proto representing the baseline analysis.
+    """
+    if constants.TIME in baseline_outcome.coords:
+      baseline_outcome = baseline_outcome.sel(
+          time=start_date,
+      )
+    incremental_outcome = baseline_outcome[constants.BASELINE_OUTCOME]
+    # Convert percentage to decimal.
+    contribution_share = baseline_outcome[constants.PCT_OF_CONTRIBUTION] / 100
+
+    contribution = outcome_pb2.Contribution(
+        value=common.to_estimate(
+            incremental_outcome, marketing_analysis_spec.confidence_level
+        ),
+        share=common.to_estimate(
+            contribution_share, marketing_analysis_spec.confidence_level
+        ),
+    )
+    baseline_analysis = non_media_analysis_pb2.NonMediaAnalysis(
+        non_media_name=constants.BASELINE,
+    )
+    baseline_outcome = outcome_pb2.Outcome(
+        contribution=contribution,
+        # Baseline outcome is always revenue-based, unless `revenue_per_kpi`
+        # is undefined.
+        # TODO: kpi_type here is synced with what is used inside
+        # `baseline_summary_metrics()`. Ideally, we should inject this value
+        # into that function rather than re-deriving it here.
+        kpi_type=(
+            kpi_type_pb2.KpiType.NON_REVENUE
+            if self._kpi_only
+            else kpi_type_pb2.KpiType.REVENUE
+        ),
+    )
+    baseline_analysis.non_media_outcomes.append(baseline_outcome)
+
+    return baseline_analysis
+
+  def _compute_outcome(
+      self,
+      media_summary_metrics: xr.Dataset,
+      is_revenue_type: bool,
+      is_all_channels: bool,
+      confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
+  ) -> outcome_pb2.Outcome:
+    """Returns an `Outcome` proto for the given channel's media analysis.
+
+    Args:
+      media_summary_metrics: A dataset containing the model's media summary
+        metrics.
+      is_revenue_type: Whether the media summary metrics above are
+        revenue-based.
+      is_all_channels: If True, the given media summary represents the
+        aggregate "All Channels". Omit `effectiveness` and `mroi` in this
+        case.
+      confidence_level: Confidence level for credible intervals, represented as
+        a value between zero and one.
+    """
+    data_vars = media_summary_metrics.data_vars
+
+    effectiveness = roi = mroi = cpik = None
+    if not is_all_channels and constants.EFFECTIVENESS in data_vars:
+      effectiveness = outcome_pb2.Effectiveness(
+          media_unit=constants.IMPRESSIONS,
+          value=common.to_estimate(
+              media_summary_metrics[constants.EFFECTIVENESS],
+              confidence_level,
+          ),
+      )
+    if not is_all_channels and constants.MROI in data_vars:
+      mroi = common.to_estimate(
+          media_summary_metrics[constants.MROI],
+          confidence_level,
+      )
+
+    contribution_value = media_summary_metrics[constants.INCREMENTAL_OUTCOME]
+    contribution = outcome_pb2.Contribution(
+        value=common.to_estimate(
+            contribution_value,
+            confidence_level,
+        ),
+    )
+    # Convert percentage to decimal.
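+    # (`pct_of_contribution` is reported on a 0-100 scale; the `share` field
+    # carries it as a fraction in [0, 1].)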
+ if constants.PCT_OF_CONTRIBUTION in data_vars: + contribution_share = ( + media_summary_metrics[constants.PCT_OF_CONTRIBUTION] / 100 + ) + contribution.share.CopyFrom( + common.to_estimate( + contribution_share, + confidence_level, + ) + ) + + if constants.CPIK in data_vars: + cpik = common.to_estimate( + media_summary_metrics[constants.CPIK], + confidence_level, + metric=constants.MEDIAN, + ) + + if constants.ROI in data_vars: + roi = common.to_estimate( + media_summary_metrics[constants.ROI], + confidence_level, + ) + + return outcome_pb2.Outcome( + kpi_type=( + kpi_type_pb2.KpiType.REVENUE + if is_revenue_type + else kpi_type_pb2.KpiType.NON_REVENUE + ), + contribution=contribution, + effectiveness=effectiveness, + cost_per_contribution=cpik, + roi=roi, + marginal_roi=mroi, + ) + + def _compute_response_curve( + self, + response_curve_dataset: xr.Dataset, + ) -> response_curve_pb2.ResponseCurve: + """Returns a `ResponseCurve` proto for the given channel. + + Args: + response_curve_dataset: A dataset containing the data needed to generate a + response curve. + """ + + spend_multiplier_list = response_curve_dataset.coords[ + constants.SPEND_MULTIPLIER + ].data + response_points: list[response_curve_pb2.ResponsePoint] = [] + + for spend_multiplier in spend_multiplier_list: + spend = ( + response_curve_dataset[constants.SPEND] + .sel(spend_multiplier=spend_multiplier) + .data.item() + ) + incremental_outcome = ( + response_curve_dataset[constants.INCREMENTAL_OUTCOME] + .sel( + spend_multiplier=spend_multiplier, + metric=constants.MEAN, + ) + .data.item() + ) + + response_point = response_curve_pb2.ResponsePoint( + input_value=spend, + incremental_kpi=incremental_outcome, + ) + response_points.append(response_point) + + return response_curve_pb2.ResponseCurve( + input_name=constants.SPEND, + response_points=response_points, + ) + + # TODO: Create an abstraction/container around these inference + # parameters. + def _incremental_outcome_dataset( + self, + resolver: model_processor.DatedSpecResolver, + new_data: analyzer.DataTensors | None = None, + media_selected_times: Sequence[bool] | None = None, + selected_geos: Sequence[str] | None = None, + aggregate_geos: bool = True, + aggregate_times: bool = True, + use_kpi: bool = False, + confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL, + batch_size: int = constants.DEFAULT_BATCH_SIZE, + include_non_paid_channels: bool = False, + ) -> xr.Dataset: + """Returns incremental outcome for each channel with dimensions. + + Args: + resolver: A `DatedSpecResolver` instance. + new_data: A dataset containing the new data to use in the analysis. + media_selected_times: A boolean array of length `n_times` indicating which + time periods are media-active. + selected_geos: Optional list containing a subset of geos to include. By + default, all geos are included. + aggregate_geos: Boolean. If `True`, the expected outcome is summed over + all of the regions. + aggregate_times: Boolean. If `True`, the expected outcome is summed over + all of the time periods. + use_kpi: Boolean. If `True`, the summary metrics are calculated using KPI. + If `False`, the metrics are calculated using revenue. + confidence_level: Confidence level for summary metrics credible intervals, + represented as a value between zero and one. + batch_size: Integer representing the maximum draws per chain in each + batch. The calculation is run in batches to avoid memory exhaustion. If + a memory error occurs, try reducing `batch_size`. 
The calculation will + generally be faster with larger `batch_size` values. + include_non_paid_channels: Boolean. If `True`, non-paid channels (organic + media, organic reach and frequency, and non-media treatments) are + included in the summary but only the metrics independent of spend are + reported. If `False`, only the paid channels (media, reach and + frequency) are included but the summary contains also the metrics + dependent on spend. Default: `False`. + + Returns: + An `xr.Dataset` and containing `incremental_outcome` for each channel. The + coordinates are: `channel` and `metric` (`mean`, `median`, `ci_low`, + `ci_high`) + """ + # Selected times in boolean form are supported by the analyzer with and + # without the new data. + selected_times_bool = resolver.resolve_to_bool_selected_times() + kwargs = { + "selected_geos": selected_geos, + "selected_times": selected_times_bool, + "aggregate_geos": aggregate_geos, + "aggregate_times": aggregate_times, + "batch_size": batch_size, + } + incremental_outcome_posterior = ( + self._analyzer.compute_incremental_outcome_aggregate( + new_data=new_data, + media_selected_times=media_selected_times, + use_posterior=True, + use_kpi=use_kpi, + include_non_paid_channels=include_non_paid_channels, + **kwargs, + ) + ) + + xr_dims = ( + ((constants.GEO,) if not aggregate_geos else ()) + + ((constants.TIME,) if not aggregate_times else ()) + + (constants.CHANNEL, constants.METRIC) + ) + channels = ( + self._meridian.input_data.get_all_channels() + if include_non_paid_channels + else self._meridian.input_data.get_all_paid_channels() + ) + xr_coords = { + constants.CHANNEL: ( + [constants.CHANNEL], + list(channels) + [constants.ALL_CHANNELS], + ), + } + if not aggregate_geos: + geo_dims = ( + self._meridian.input_data.geo.data + if selected_geos is None + else selected_geos + ) + xr_coords[constants.GEO] = ([constants.GEO], geo_dims) + if not aggregate_times: + selected_times_str = resolver.resolve_to_enumerated_selected_times() + if selected_times_str is not None: + time_dims = selected_times_str + else: + time_dims = resolver.time_coordinates.all_dates_str + xr_coords[constants.TIME] = ([constants.TIME], time_dims) + xr_coords_with_ci = { + constants.METRIC: ( + [constants.METRIC], + [ + constants.MEAN, + constants.MEDIAN, + constants.CI_LO, + constants.CI_HI, + ], + ), + **xr_coords, + } + metrics = analyzer.get_central_tendency_and_ci( + incremental_outcome_posterior, confidence_level, include_median=True + ) + xr_data = {constants.INCREMENTAL_OUTCOME: (xr_dims, metrics)} + return xr.Dataset(data_vars=xr_data, coords=xr_coords_with_ci) + + +def _compute_spend( + media_summary_metrics: xr.Dataset, +) -> media_analysis_pb2.SpendInfo | None: + """Returns a `SpendInfo` proto with spend information for the given channel. + + Args: + media_summary_metrics: A dataset containing the model's media summary + metrics. + """ + if constants.SPEND not in media_summary_metrics.data_vars: + return None + + spend = media_summary_metrics[constants.SPEND].item() + spend_share = media_summary_metrics[constants.PCT_OF_SPEND].data.item() / 100 + + return media_analysis_pb2.SpendInfo( + spend=spend, + spend_share=spend_share, + ) diff --git a/schema/processors/model_fit_processor.py b/schema/processors/model_fit_processor.py new file mode 100644 index 000000000..35691a3dc --- /dev/null +++ b/schema/processors/model_fit_processor.py @@ -0,0 +1,367 @@ +# Copyright 2025 The Meridian Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Meridian module for analyzing model fit in a Meridian model. + +This module provides a `ModelFitProcessor`, which assesses the goodness of fit +of a trained Meridian model. It compares the model's predictions against the +actual observed data, generating key performance metrics. + +Key metrics generated include R-squared, MAPE, and Weighted MAPE. The output +also includes timeseries data of actual values versus predicted values (with +confidence intervals) and the predicted baseline. + +The results are structured into a `ModelFit` protobuf message. + +Key Classes: + +- `ModelFitSpec`: Dataclass to specify parameters for the model fit analysis, + such as whether to split by train/test sets and the confidence level for + intervals. +- `ModelFitProcessor`: The processor class that performs the fit analysis. + +Example Usage: + +```python +from schema.processors import model_fit_processor +from schema.processors import model_processor + +# Assuming 'mmm' is a trained Meridian model object +trained_model = model_processor.TrainedModel(mmm) + +# Default spec: split results by train/test if holdout ID exists +spec = model_fit_processor.ModelFitSpec() + +processor = model_fit_processor.ModelFitProcessor(trained_model) +# result is a model_fit_pb2.ModelFit proto +result = processor.execute([spec]) + +print("Model Fit Analysis Results:") +for res in result.results: + print(f" Dataset: {res.name}") + print(f" R-squared: {res.performance.r_squared:.3f}") + print(f" MAPE: {res.performance.mape:.3f}") + print(f" Weighted MAPE: {res.performance.weighted_mape:.3f}") + # Prediction data is available in res.predictions + # Each element in res.predictions corresponds to a time point. + # e.g., res.predictions[0].actual_value + # e.g., res.predictions[0].predicted_outcome.value +``` + +Note: Only one spec is supported per processor execution. +""" + +from collections.abc import Mapping, Sequence +import dataclasses +import warnings + +from meridian import constants +from mmm.v1 import mmm_pb2 +from mmm.v1.common import date_interval_pb2 +from mmm.v1.common import estimate_pb2 +from mmm.v1.fit import model_fit_pb2 +from schema.processors import model_processor +from schema.utils import time_record +import xarray as xr + + +__all__ = [ + "ModelFitSpec", + "ModelFitProcessor", +] + + +@dataclasses.dataclass(frozen=True) +class ModelFitSpec(model_processor.Spec): + """Stores parameters needed for generating ModelFit protos. + + Attributes: + split: If `True` and Meridian model contains holdout IDs, results are + generated for `'Train'`, `'Test'`, and `'All Data'` sets. + confidence_level: Confidence level for prior and posterior credible + intervals, represented as a value between zero and one. + """ + + split: bool = True + confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL + + def validate(self): + if self.confidence_level <= 0 or self.confidence_level >= 1: + raise ValueError( + "Confidence level must be greater than 0 and less than 1." 
+ ) + + +class ModelFitProcessor( + model_processor.ModelProcessor[ModelFitSpec, model_fit_pb2.ModelFit] +): + """Generates a ModelFit proto for a given trained Meridian model. + + The proto contains performance metrics for each dataset as well as a list of + predictions. + """ + + def __init__( + self, + trained_model: model_processor.ModelType, + ): + trained_model = model_processor.ensure_trained_model(trained_model) + self._analyzer = trained_model.internal_analyzer + self._time_coordinates = trained_model.time_coordinates + + @classmethod + def spec_type(cls) -> type[ModelFitSpec]: + return ModelFitSpec + + @classmethod + def output_type(cls) -> type[model_fit_pb2.ModelFit]: + return model_fit_pb2.ModelFit + + def _set_output(self, output: mmm_pb2.Mmm, result: model_fit_pb2.ModelFit): + output.model_fit.CopyFrom(result) + + def execute(self, specs: Sequence[ModelFitSpec]) -> model_fit_pb2.ModelFit: + model_fit_spec = specs[0] + if len(specs) > 1: + warnings.warn( + "Multiple specs were provided. Only the first one will be used." + ) + + expected_vs_actual = self._analyzer.expected_vs_actual_data( + confidence_level=model_fit_spec.confidence_level, + split_by_holdout_id=model_fit_spec.split, + aggregate_geos=True, + ) + metrics = self._analyzer.predictive_accuracy() + time_to_date_interval = time_record.convert_times_to_date_intervals( + self._time_coordinates.datetime_index + ) + + results: list[model_fit_pb2.Result] = [] + + if constants.EVALUATION_SET_VAR in expected_vs_actual.coords: + results.append( + self._create_result( + result_type=constants.TRAIN, + expected_vs_actual=expected_vs_actual.sel( + evaluation_set=constants.TRAIN + ), + metrics=metrics.sel(evaluation_set=constants.TRAIN), + model_fit_spec=model_fit_spec, + time_to_date_interval=time_to_date_interval, + ) + ) + results.append( + self._create_result( + result_type=constants.TEST, + expected_vs_actual=expected_vs_actual.sel( + evaluation_set=constants.TEST + ), + metrics=metrics.sel(evaluation_set=constants.TEST), + model_fit_spec=model_fit_spec, + time_to_date_interval=time_to_date_interval, + ) + ) + results.append( + self._create_result( + result_type=constants.ALL_DATA, + expected_vs_actual=expected_vs_actual.sel( + evaluation_set=constants.ALL_DATA + ), + metrics=metrics.sel(evaluation_set=constants.ALL_DATA), + model_fit_spec=model_fit_spec, + time_to_date_interval=time_to_date_interval, + ) + ) + else: + results.append( + self._create_result( + result_type=constants.ALL_DATA, + expected_vs_actual=expected_vs_actual, + metrics=metrics, + model_fit_spec=model_fit_spec, + time_to_date_interval=time_to_date_interval, + ) + ) + + return model_fit_pb2.ModelFit(results=results) + + def _create_result( + self, + result_type: str, + expected_vs_actual: xr.Dataset, + metrics: xr.Dataset, + model_fit_spec: ModelFitSpec, + time_to_date_interval: Mapping[str, date_interval_pb2.DateInterval], + ) -> model_fit_pb2.Result: + """Creates a proto that stores the model fit results for an evaluation set. + + Args: + result_type: The evaluation set (`"Train"`, `"Test"`, or `"All Data"`) for + the result. + expected_vs_actual: A dataset containing the expected and actual values + for the model. This dataset is filtered by the evaluation set in the + calling code. + metrics: A dataset containing the performance metrics for the model. This + dataset is filtered by the evaluation set in the calling code. + model_fit_spec: An instance of ModelFitSpec. 
+ time_to_date_interval: A mapping of date strings (in YYYY-MM-DD format) to + date interval protos. + + Returns: + A proto containing the results of the model fit analysis for the given + evaluation set. + """ + + predictions: list[model_fit_pb2.Prediction] = [] + + for start_date in self._time_coordinates.all_dates_str: + date_interval = time_to_date_interval[start_date] + actual = ( + expected_vs_actual.data_vars[constants.ACTUAL] + .sel( + time=start_date, + ) + .item() + ) + expected_dataset = expected_vs_actual[constants.EXPECTED].sel( + time=start_date, + ) + expected = expected_dataset.sel(metric=constants.MEAN).item() + expected_lowerbound = expected_dataset.sel(metric=constants.CI_LO).item() + expected_upperbound = expected_dataset.sel(metric=constants.CI_HI).item() + baseline_dataset = expected_vs_actual[constants.BASELINE].sel( + time=start_date, + ) + baseline = baseline_dataset.sel(metric=constants.MEAN).item() + baseline_lowerbound = baseline_dataset.sel(metric=constants.CI_LO).item() + baseline_upperbound = baseline_dataset.sel(metric=constants.CI_HI).item() + + prediction = self._create_prediction( + model_fit_spec=model_fit_spec, + date_interval=date_interval, + actual_value=actual, + estimated_value=expected, + estimated_lower_bound=expected_lowerbound, + estimated_upper_bound=expected_upperbound, + baseline_value=baseline, + baseline_lower_bound=baseline_lowerbound, + baseline_upper_bound=baseline_upperbound, + ) + predictions.append(prediction) + + performance = self._evaluate_model_fit(metrics) + + return model_fit_pb2.Result( + name=result_type, predictions=predictions, performance=performance + ) + + def _create_prediction( + self, + model_fit_spec: ModelFitSpec, + date_interval: date_interval_pb2.DateInterval, + actual_value: float, + estimated_value: float, + estimated_lower_bound: float, + estimated_upper_bound: float, + baseline_value: float, + baseline_lower_bound: float, + baseline_upper_bound: float, + ) -> model_fit_pb2.Prediction: + """Creates a proto that stores the model's prediction for the given date. + + Args: + model_fit_spec: An instance of ModelFitSpec. + date_interval: A DateInterval proto containing the start date and end date + for this prediction. + actual_value: The model's actual value for this date. + estimated_value: The model's estimated value for this date. + estimated_lower_bound: The lower bound of the estimated value's confidence + interval. + estimated_upper_bound: The upper bound of the estimated value's confidence + interval. + baseline_value: The baseline value for this date. + baseline_lower_bound: The lower bound of the baseline value's confidence + interval. + baseline_upper_bound: The upper bound of the baseline value's confidence + interval. + + Returns: + A proto containing the model's predicted value and actual value for the + given date. 
+ """ + + estimate = estimate_pb2.Estimate(value=estimated_value) + estimate.uncertainties.add( + probability=model_fit_spec.confidence_level, + lowerbound=estimated_lower_bound, + upperbound=estimated_upper_bound, + ) + + baseline_estimate = estimate_pb2.Estimate(value=baseline_value) + baseline_estimate.uncertainties.add( + probability=model_fit_spec.confidence_level, + lowerbound=baseline_lower_bound, + upperbound=baseline_upper_bound, + ) + + return model_fit_pb2.Prediction( + date_interval=date_interval, + predicted_outcome=estimate, + predicted_baseline=baseline_estimate, + actual_value=actual_value, + ) + + def _evaluate_model_fit( + self, + metrics: xr.Dataset, + ) -> model_fit_pb2.Performance: + """Creates a proto that stores the model's performance metrics. + + Args: + metrics: A dataset containing the performance metrics for the model. This + dataset is filtered by evaluation set before this function is called. + + Returns: + A proto containing the model's performance metrics for a specific + evaluation set. + """ + + performance = model_fit_pb2.Performance() + performance.r_squared = ( + metrics[constants.VALUE] + .sel( + geo_granularity=constants.NATIONAL, + metric=constants.R_SQUARED, + ) + .item() + ) + performance.mape = ( + metrics[constants.VALUE] + .sel( + geo_granularity=constants.NATIONAL, + metric=constants.MAPE, + ) + .item() + ) + performance.weighted_mape = ( + metrics[constants.VALUE] + .sel( + geo_granularity=constants.NATIONAL, + metric=constants.WMAPE, + ) + .item() + ) + + return performance diff --git a/schema/processors/model_fit_processor_test.py b/schema/processors/model_fit_processor_test.py new file mode 100644 index 000000000..9e2a68998 --- /dev/null +++ b/schema/processors/model_fit_processor_test.py @@ -0,0 +1,503 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for model_fit_processor.py.""" + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from meridian import constants +from meridian.analysis import analyzer +from meridian.data import time_coordinates as tc +from meridian.model import model +from mmm.v1.common import date_interval_pb2 +from mmm.v1.common import estimate_pb2 +from mmm.v1.fit import model_fit_pb2 +from schema.processors import model_fit_processor +from schema.processors import model_processor +import numpy as np +import xarray as xr + +from google.type import date_pb2 +from tensorflow.python.util.protobuf import compare + + +_ALL_TIMES = xr.DataArray( + np.array([ + "2024-01-01", + "2024-01-08", + "2024-01-15", + ]) +) + +_EXPECTED_RESULT_PROTO_TRAIN = model_fit_pb2.Result( + name=constants.TRAIN, + performance=model_fit_pb2.Performance( + r_squared=0.9, mape=0.88, weighted_mape=0.95 + ), + predictions=[ + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=1), + end_date=date_pb2.Date(year=2024, month=1, day=8), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.75, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.62, upperbound=0.96 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.65, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.52, upperbound=0.86 + ) + ], + ), + actual_value=0.75, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=8), + end_date=date_pb2.Date(year=2024, month=1, day=15), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.7, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.6, upperbound=0.95 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.6, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.5, upperbound=0.85 + ) + ], + ), + actual_value=0.7, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=15), + end_date=date_pb2.Date(year=2024, month=1, day=22), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.85, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.75, upperbound=0.97 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.75, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.65, upperbound=0.87 + ) + ], + ), + actual_value=0.85, + ), + ], +) + +_EXPECTED_RESULT_PROTO_TEST = model_fit_pb2.Result( + name=constants.TEST, + performance=model_fit_pb2.Performance( + r_squared=0.74, mape=0.68, weighted_mape=0.83 + ), + predictions=[ + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=1), + end_date=date_pb2.Date(year=2024, month=1, day=8), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.75, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.65, upperbound=0.86 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.65, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.55, upperbound=0.76 + ) + ], + ), + actual_value=0.62, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, 
month=1, day=8), + end_date=date_pb2.Date(year=2024, month=1, day=15), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.65, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.6, upperbound=0.84 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.55, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.5, upperbound=0.74 + ) + ], + ), + actual_value=0.6, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=15), + end_date=date_pb2.Date(year=2024, month=1, day=22), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.85, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.7, upperbound=0.88 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.75, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.6, upperbound=0.78 + ) + ], + ), + actual_value=0.75, + ), + ], +) + +_EXPECTED_RESULT_PROTO_ALL_DATA = model_fit_pb2.Result( + name=constants.ALL_DATA, + performance=model_fit_pb2.Performance( + r_squared=0.91, mape=0.87, weighted_mape=0.98 + ), + predictions=[ + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=1), + end_date=date_pb2.Date(year=2024, month=1, day=8), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.9, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.83, upperbound=0.71 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.8, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.73, upperbound=0.61 + ) + ], + ), + actual_value=0.96, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=8), + end_date=date_pb2.Date(year=2024, month=1, day=15), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.83, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.75, upperbound=0.65 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.73, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.65, upperbound=0.55 + ) + ], + ), + actual_value=0.95, + ), + model_fit_pb2.Prediction( + date_interval=date_interval_pb2.DateInterval( + start_date=date_pb2.Date(year=2024, month=1, day=15), + end_date=date_pb2.Date(year=2024, month=1, day=22), + ), + predicted_outcome=estimate_pb2.Estimate( + value=0.96, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.9, upperbound=0.77 + ) + ], + ), + predicted_baseline=estimate_pb2.Estimate( + value=0.86, + uncertainties=[ + estimate_pb2.Estimate.Uncertainty( + probability=0.9, lowerbound=0.8, upperbound=0.67 + ) + ], + ), + actual_value=0.97, + ), + ], +) + + +def _create_expected_model_fit(split: bool) -> model_fit_pb2.ModelFit: + if split: + return model_fit_pb2.ModelFit( + results=[ + _EXPECTED_RESULT_PROTO_TRAIN, + _EXPECTED_RESULT_PROTO_TEST, + _EXPECTED_RESULT_PROTO_ALL_DATA, + ] + ) + else: + return model_fit_pb2.ModelFit( + results=[ + _EXPECTED_RESULT_PROTO_ALL_DATA, + ] + ) + + +def _create_expected_vs_actual_data(split: bool) -> xr.Dataset: + xr_dims_expected = ( + constants.TIME, + constants.METRIC, + ) + ((constants.EVALUATION_SET_VAR,) if split else ()) + xr_dims_baseline = xr_dims_expected + xr_dims_actual = (constants.TIME,) + 
( + (constants.EVALUATION_SET_VAR,) if split else () + ) + xr_coords = { + constants.TIME: ( + [constants.TIME], + _ALL_TIMES.data, + ), + constants.METRIC: ( + [constants.METRIC], + [constants.MEAN, constants.CI_LO, constants.CI_HI], + ), + } + if split: + xr_coords.update({ + constants.EVALUATION_SET_VAR: ( + [constants.EVALUATION_SET_VAR], + list(constants.EVALUATION_SET), + ) + }) + + time_1_train = [0.75, 0.7, 0.85] + time_1_test = [0.75, 0.65, 0.85] + time_1_all_data = [0.9, 0.83, 0.96] + + time_2_train = [0.62, 0.6, 0.75] + time_2_test = [0.65, 0.6, 0.7] + time_2_all_data = [0.83, 0.75, 0.9] + + time_3_train = [0.96, 0.95, 0.97] + time_3_test = [0.86, 0.84, 0.88] + time_3_all_data = [0.71, 0.65, 0.77] + + stacked_train = np.stack([time_1_train, time_2_train, time_3_train], axis=-1) + stacked_test = np.stack([time_1_test, time_2_test, time_3_test], axis=-1) + stacked_all_data = np.stack( + [time_1_all_data, time_2_all_data, time_3_all_data], axis=-1 + ) + stacked_total = np.stack( + [stacked_train, stacked_test, stacked_all_data], + axis=-1, + ) + + xr_data = { + constants.EXPECTED: ( + xr_dims_expected, + stacked_total if split else stacked_all_data, + ), + constants.BASELINE: ( + xr_dims_baseline, + (stacked_total if split else stacked_all_data) - 0.1, + ), + constants.ACTUAL: ( + xr_dims_actual, + stacked_train if split else time_3_train, + ), + } + + return xr.Dataset(data_vars=xr_data, coords=xr_coords) + + +def _create_predictive_accuracy_data(split: bool) -> xr.Dataset: + xr_dims = ( + constants.METRIC, + constants.GEO_GRANULARITY, + ) + ((constants.EVALUATION_SET_VAR,) if split else ()) + xr_coords = { + constants.METRIC: ( + [constants.METRIC], + [constants.R_SQUARED, constants.MAPE, constants.WMAPE], + ), + constants.GEO_GRANULARITY: ( + [constants.GEO_GRANULARITY], + [constants.GEO, constants.NATIONAL], + ), + } + if split: + xr_coords.update({ + constants.EVALUATION_SET_VAR: ( + [constants.EVALUATION_SET_VAR], + list(constants.EVALUATION_SET), + ) + }) + + geo_train = [0.8, 0.75, 0.85] + national_train = [0.9, 0.88, 0.95] + geo_test = [0.75, 0.65, 0.85] + national_test = [0.74, 0.68, 0.83] + geo_all_data = [0.93, 0.9, 0.96] + national_all_data = [0.91, 0.87, 0.98] + + stacked_train = np.stack([geo_train, national_train], axis=-1) + stacked_test = np.stack([geo_test, national_test], axis=-1) + stacked_all_data = np.stack([geo_all_data, national_all_data], axis=-1) + stacked_total = np.stack( + [stacked_train, stacked_test, stacked_all_data], axis=-1 + ) + + xr_data = { + constants.VALUE: (xr_dims, stacked_total if split else stacked_all_data) + } + + return xr.Dataset(data_vars=xr_data, coords=xr_coords) + + +class ModelFitSpecTest(absltest.TestCase): + + def test_confidence_level_is_below_zero(self): + with self.assertRaisesRegex( + ValueError, + "Confidence level must be greater than 0 and less than 1.", + ): + spec = model_fit_processor.ModelFitSpec(confidence_level=-1) + spec.validate() + + def test_confidence_level_is_above_one(self): + with self.assertRaisesRegex( + ValueError, + "Confidence level must be greater than 0 and less than 1.", + ): + spec = model_fit_processor.ModelFitSpec(confidence_level=1.5) + spec.validate() + + def test_validates_successfully(self): + spec = model_fit_processor.ModelFitSpec(split=True, confidence_level=0.95) + + spec.validate() + + self.assertEqual(spec.split, True) + self.assertEqual(spec.confidence_level, 0.95) + + +class ModelFitProcessorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + 
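+ # Patch the Meridian model, analyzer, and trained-model wrapper so the
+ # processor under test runs entirely against the synthetic fixtures
+ # defined above, without fitting a real model.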
self.mock_meridian_model = self.enter_context(
+ mock.patch.object(model, "Meridian", autospec=True)
+ )
+ self.mock_meridian_model.input_data.time = _ALL_TIMES
+
+ self.mock_analyzer = self.enter_context(
+ mock.patch.object(analyzer, "Analyzer", autospec=True)
+ )
+
+ self.mock_trained_model = self.enter_context(
+ mock.patch.object(model_processor, "TrainedModel", autospec=True)
+ )
+ self.mock_trained_model.mmm = self.mock_meridian_model
+ self.mock_trained_model.internal_analyzer = self.mock_analyzer
+ self.mock_trained_model.time_coordinates = tc.TimeCoordinates.from_dates(
+ _ALL_TIMES
+ )
+
+ self.mock_ensure_trained_model = self.enter_context(
+ mock.patch.object(
+ model_processor, "ensure_trained_model", autospec=True
+ )
+ )
+ self.mock_ensure_trained_model.return_value = self.mock_trained_model
+
+ def test_spec_type_returns_model_fit_spec(self):
+ self.assertEqual(
+ model_fit_processor.ModelFitProcessor.spec_type(),
+ model_fit_processor.ModelFitSpec,
+ )
+
+ def test_output_type_returns_model_fit_proto(self):
+ self.assertEqual(
+ model_fit_processor.ModelFitProcessor.output_type(),
+ model_fit_pb2.ModelFit,
+ )
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name="multiple_evaluation_sets",
+ split=True,
+ ),
+ dict(testcase_name="single_evaluation_set", split=False),
+ )
+ def test_execute(self, split: bool):
+ self.mock_analyzer.expected_vs_actual_data.return_value = (
+ _create_expected_vs_actual_data(split)
+ )
+ self.mock_analyzer.predictive_accuracy.return_value = (
+ _create_predictive_accuracy_data(split)
+ )
+
+ spec = model_fit_processor.ModelFitSpec(split)
+ model_fit = model_fit_processor.ModelFitProcessor(
+ trained_model=self.mock_trained_model,
+ ).execute([spec])
+
+ compare.assertProtoEqual(
+ self, model_fit, _create_expected_model_fit(split)
+ )
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/schema/processors/model_kernel_processor.py b/schema/processors/model_kernel_processor.py
new file mode 100644
index 000000000..2d8cc5508
--- /dev/null
+++ b/schema/processors/model_kernel_processor.py
@@ -0,0 +1,117 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for transforming a Meridian model into a structured MMM schema.
+
+This module provides the `ModelKernelProcessor`, which is responsible for
+transforming the internal state of a trained Meridian model object into a
+structured and portable format defined by the `MmmKernel` protobuf message.
+
+The "kernel" includes essential information about the model, such as:
+
+- Model specifications and hyperparameters.
+- Inferred parameter distributions (as serialized ArviZ inference data).
+- MMM-agnostic marketing data (i.e., input data to the model).
+
+This serialized representation allows the model to be saved, loaded, and
+analyzed across different environments or by other tools that understand the
+`MmmKernel` schema.
+
+The serialization logic is primarily handled by the `MeridianSerde` class from
+the `schema.serde` package.
+
+Key Classes:
+
+- `ModelKernelProcessor`: The processor class that takes a Meridian model
+ instance and populates an `MmmKernel` message.
+
+Example Usage:
+
+```python
+import meridian
+from meridian.model import model
+from mmm.v1 import mmm_pb2
+from schema.processors import model_kernel_processor
+import semver
+
+# Assuming 'mmm' is a `meridian.model.Meridian` object.
+# Example:
+# mmm = meridian.model.Meridian(...)
+# mmm.sample_prior(...)
+# mmm.sample_posterior(...)
+
+processor = model_kernel_processor.ModelKernelProcessor(
+ meridian_model=mmm,
+ model_id="my_model_v1",
+)
+
+# Create an output Mmm proto message
+output_proto = mmm_pb2.Mmm()
+
+# Populate the mmm_kernel field
+processor(output_proto)
+
+# Now output_proto.mmm_kernel contains the serialized model.
+# This can be saved to a file, sent over a network, etc.
+print(f"Model Kernel ID: {output_proto.mmm_kernel.model_id}")
+print(f"Meridian Version: {output_proto.mmm_kernel.meridian_version}")
+# Access other fields within output_proto.mmm_kernel as needed.
+```
+"""
+
+import abc
+
+import meridian
+from meridian.model import model
+from mmm.v1 import mmm_pb2 as pb
+from schema.serde import meridian_serde
+import semver
+
+
+class ModelKernelProcessor(abc.ABC):
+ """Transcribes a model's stats into an `MmmKernel` message."""
+
+ def __init__(
+ self,
+ meridian_model: model.Meridian,
+ model_id: str = '',
+ meridian_version: semver.VersionInfo = semver.VersionInfo.parse(
+ meridian.__version__
+ ),
+ ):
+ """Initializes this `ModelKernelProcessor` with a Meridian model.
+
+ Args:
+ meridian_model: A Meridian model.
+ model_id: An optional model identifier unique to the given model.
+ meridian_version: The version of the current Meridian framework.
+ """
+ self._meridian = meridian_model
+ self._model_id = model_id
+ self._meridian_version = meridian_version
+
+ def __call__(self, output: pb.Mmm):
+ """Sets the `mmm_kernel` field in the given `Mmm` proto.
+
+ Args:
+ output: The output proto to modify.
+ """
+ output.mmm_kernel.CopyFrom(
+ meridian_serde.MeridianSerde().serialize(
+ self._meridian,
+ self._model_id,
+ self._meridian_version,
+ include_convergence_info=True,
+ )
+ )
diff --git a/schema/processors/model_kernel_processor_test.py b/schema/processors/model_kernel_processor_test.py
new file mode 100644
index 000000000..20f9aac0d
--- /dev/null
+++ b/schema/processors/model_kernel_processor_test.py
@@ -0,0 +1,59 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import meridian
+from meridian.model import model
+from mmm.v1 import mmm_pb2 as pb
+from mmm.v1.model import mmm_kernel_pb2 as kernel_pb
+from schema.processors import model_kernel_processor
+from schema.serde import meridian_serde
+import semver
+
+from tensorflow.python.util.protobuf import compare
+
+
+class ModelKernelProcessorTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self._model_id = 'test_model'
+ self._meridian_version = semver.VersionInfo.parse(meridian.__version__)
+ self._mock_meridian = mock.MagicMock(spec=model.Meridian)
+ self._processor = model_kernel_processor.ModelKernelProcessor(
+ meridian_model=self._mock_meridian,
+ model_id=self._model_id,
+ )
+
+ @mock.patch.object(meridian_serde.MeridianSerde, 'serialize')
+ def test_call(self, mock_serialize):
+ mock_serialize.return_value = kernel_pb.MmmKernel()
+ output = pb.Mmm()
+ self._processor(output)
+ self.assertTrue(output.HasField('mmm_kernel'))
+ mock_serialize.assert_called_once_with(
+ self._mock_meridian,
+ self._model_id,
+ self._meridian_version,
+ include_convergence_info=True,
+ )
+ compare.assertProtoEqual(self, output.mmm_kernel, kernel_pb.MmmKernel())
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/schema/processors/model_processor.py b/schema/processors/model_processor.py
new file mode 100644
index 000000000..222b1e15f
--- /dev/null
+++ b/schema/processors/model_processor.py
@@ -0,0 +1,412 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines common and base classes for processing a trained Meridian model into an MMM schema."""
+
+import abc
+from collections.abc import Sequence
+import dataclasses
+import datetime
+import functools
+from typing import Generic, TypeVar
+
+from google.protobuf import message
+from meridian import constants as c
+from meridian.analysis import analyzer
+from meridian.analysis import optimizer
+from meridian.analysis import visualizer
+from meridian.data import time_coordinates as tc
+from meridian.model import model
+from mmm.v1 import mmm_pb2 as pb
+from mmm.v1.common import date_interval_pb2
+from schema.utils import time_record
+from typing_extensions import override
+
+
+__all__ = [
+ 'ModelProcessor',
+ 'TrainedModel',
+ 'DatedSpec',
+ 'DatedSpecResolver',
+ 'OptimizationSpec',
+ 'ensure_trained_model',
+]
+
+
+class TrainedModel(abc.ABC):
+ """Encapsulates a trained MMM model."""
+
+ def __init__(self, mmm: model.Meridian):
+ """Initializes the TrainedModel with a Meridian model.
+
+ Args:
+ mmm: A Meridian model that has been fitted (posterior samples drawn).
+
+ Raises:
+ ValueError: If the model has not been fitted (posterior samples drawn).
+ """
+ # Ideally, this could be encoded in the model type itself, and we wouldn't
+ # need this extra runtime check.
+ if mmm.inference_data.prior is None or mmm.inference_data.posterior is None: # pytype: disable=attribute-error + raise ValueError('MMM model has not been fitted.') + self._mmm = mmm + + @property + def mmm(self) -> model.Meridian: + return self._mmm + + @property + def time_coordinates(self) -> tc.TimeCoordinates: + return self._mmm.input_data.time_coordinates + + @functools.cached_property + def internal_analyzer(self) -> analyzer.Analyzer: + """Returns an internal `Analyzer` bound to this trained model.""" + return analyzer.Analyzer(self.mmm) + + @functools.cached_property + def internal_optimizer(self) -> optimizer.BudgetOptimizer: + """Returns an internal `BudgetOptimizer` bound to this trained model.""" + return optimizer.BudgetOptimizer(self.mmm) + + @functools.cached_property + def internal_model_diagnostics(self) -> visualizer.ModelDiagnostics: + """Returns an internal `ModelDiagnostics` bound to this trained model.""" + return visualizer.ModelDiagnostics(self.mmm) + + +ModelType = model.Meridian | TrainedModel + + +def ensure_trained_model(model_input: ModelType) -> TrainedModel: + """Ensure the given model is a trained model, and wrap it in a TrainedModel.""" + if isinstance(model_input, TrainedModel): + return model_input + return TrainedModel(model_input) + + +class Spec(abc.ABC): + """Contains parameters needed for model-based analysis/optimization.""" + + @abc.abstractmethod + def validate(self): + """Checks whether each parameter in the Spec has a valid value. + + Raises: + ValueError: If any parameter in the Spec has an invalid value. + """ + + def __post_init__(self): + self.validate() + + +@dataclasses.dataclass(frozen=True) +class DatedSpec(Spec): + """A spec with a `[start_date, end_date)` closed-open date range semantic. + + Attrs: + start_date: The start date of the analysis/optimization. If left as `None`, + then this will eventually resolve to a model's first time coordinate. + end_date: The end date of the analysis/optimization. If left as `None`, then + this will eventually resolve to a model's last time coordinate. When + specified, this end date is exclusive. + date_interval_tag: An optional tag that identifies the date interval. + """ + + start_date: datetime.date | None = None + end_date: datetime.date | None = None + date_interval_tag: str = '' + + @override + def validate(self): + """Overrides the Spec.validate() method to check that dates are valid.""" + if ( + self.start_date is not None + and self.end_date is not None + and self.start_date > self.end_date + ): + raise ValueError('Start date must be before end date.') + + def resolver( + self, time_coordinates: tc.TimeCoordinates + ) -> 'DatedSpecResolver': + """Returns a date resolver for this spec, with the given Meridian model.""" + return DatedSpecResolver(self, time_coordinates) + + +class DatedSpecResolver: + """Resolves date parameters in specs based on a model's time coordinates.""" + + def __init__(self, spec: DatedSpec, time_coordinates: tc.TimeCoordinates): + self._spec = spec + self._time_coordinates = time_coordinates + + @property + def _interval_days(self) -> int: + return self._time_coordinates.interval_days + + @property + def time_coordinates(self) -> tc.TimeCoordinates: + return self._time_coordinates + + def to_closed_date_interval_tuple( + self, + ) -> tuple[str | None, str | None]: + """Transforms given spec into a closed `[start, end]` date interval tuple. 
+
+ For each of the bookends in the tuple, a `None` value indicates a time
+ coordinate default (first or last time coordinate, respectively).
+
+ Returns:
+ A **closed** `[start, end]` date interval tuple.
+ """
+ start, end = (None, None)
+
+ if self._spec.start_date is not None:
+ start = self._spec.start_date.strftime(c.DATE_FORMAT)
+ if self._spec.end_date is not None:
+ inclusive_end_date = self._spec.end_date - datetime.timedelta(
+ days=self._interval_days
+ )
+ end = inclusive_end_date.strftime(c.DATE_FORMAT)
+
+ return (start, end)
+
+ def resolve_to_enumerated_selected_times(self) -> list[str] | None:
+ """Resolves the given spec into an enumerated list of time coordinates.
+
+ Returns:
+ An enumerated list of time coordinates, or `None` (meaning "all times")
+ if the bound spec's date range is unbounded.
+ """
+ start, end = self.to_closed_date_interval_tuple()
+ expanded = self._time_coordinates.expand_selected_time_dims(
+ start_date=start, end_date=end
+ )
+ if expanded is None:
+ return None
+ return [date.strftime(c.DATE_FORMAT) for date in expanded]
+
+ def resolve_to_bool_selected_times(self) -> list[bool] | None:
+ """Resolves the given spec into a list of booleans indicating selected times.
+
+ Returns:
+ A list of booleans indicating selected times, or `None` (meaning "all
+ times") if the bound spec's date range is unbounded.
+ """
+ selected_times = self.resolve_to_enumerated_selected_times()
+ if selected_times is None:
+ return None
+ return [
+ time in selected_times for time in self._time_coordinates.all_dates_str
+ ]
+
+ def collapse_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
+ """Collapses the given spec into a `DateInterval` proto.
+
+ If the spec's date range is unbounded, then the DateInterval proto will have
+ the semantic "All", and we resolve it by consulting the time coordinates of
+ the model bound to this resolver.
+
+ Note that the exclusive end date semantic will be preserved in the returned
+ proto.
+
+ Returns:
+ A `DateInterval` proto that represents the date interval specified by the
+ spec.
+ """
+ selected_times = self.resolve_to_enumerated_selected_times()
+ if selected_times is None:
+ start_date = self._time_coordinates.all_dates[0]
+ end_date = self._time_coordinates.all_dates[-1]
+ else:
+ normalized_selected_times = [
+ tc.normalize_date(date) for date in selected_times
+ ]
+ start_date = normalized_selected_times[0]
+ end_date = normalized_selected_times[-1]
+
+ # Adjust end_date to make it exclusive.
+ end_date += datetime.timedelta(days=self._interval_days)
+
+ return time_record.create_date_interval_pb(
+ start_date, end_date, tag=self._spec.date_interval_tag
+ )
+
+ def transform_to_date_interval_protos(
+ self,
+ ) -> list[date_interval_pb2.DateInterval]:
+ """Transforms the given spec into `DateInterval` protos.
+
+ If the spec's date range is unbounded, then the DateInterval proto will have
+ the semantic "All", and we resolve it by consulting the time coordinates of
+ the model bound to this resolver.
+
+ Note that the exclusive end date semantic will be preserved in the returned
+ proto.
+
+ Returns:
+ A list of `DateInterval` protos that represent the date intervals
+ specified by the spec.
+ """
+ selected_times = self.resolve_to_enumerated_selected_times()
+ if selected_times is None:
+ times_list = self._time_coordinates.all_dates
+ else:
+ times_list = [tc.normalize_date(date) for date in selected_times]
+
+ date_intervals = []
+ for start_date in times_list:
+ date_interval = time_record.create_date_interval_pb(
+ start_date=start_date,
+ end_date=start_date + datetime.timedelta(days=self._interval_days),
+ tag=self._spec.date_interval_tag,
+ )
+ date_intervals.append(date_interval)
+
+ return date_intervals
+
+ def resolve_to_date_interval_open_end(
+ self,
+ ) -> tuple[datetime.date, datetime.date]:
+ """Resolves the given spec into a half-open `[start, end)` date interval."""
+ start = self._spec.start_date or self._time_coordinates.all_dates[0]
+ end = self._spec.end_date
+ if end is None:
+ end = self._time_coordinates.all_dates[-1]
+ # Adjust `end` to make it exclusive, but only if we pulled it from the
+ # time coordinates.
+ end += datetime.timedelta(days=self._interval_days)
+ return (start, end)
+
+ def resolve_to_date_interval_proto(self) -> date_interval_pb2.DateInterval:
+ """Resolves the given spec into a fully specified `DateInterval` proto.
+
+ If either `start_date` or `end_date` is None in the bound spec, then we
+ resolve it by consulting the time coordinates of the model bound to this
+ resolver. They are resolved to the first and last time coordinates (plus
+ interval length), respectively.
+
+ Note that the exclusive end date semantic will be preserved in the returned
+ proto.
+
+ Returns:
+ A resolved `DateInterval` proto that represents the date interval
+ specified by the bound spec.
+ """
+ start_date, end_date = self.resolve_to_date_interval_open_end()
+ return time_record.create_date_interval_pb(
+ start_date, end_date, tag=self._spec.date_interval_tag
+ )
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class OptimizationSpec(DatedSpec):
+ """A dated spec for optimization.
+
+ Attrs:
+ optimization_name: The name of the optimization in this spec.
+ grid_name: The name of the optimization grid.
+ group_id: An optional group ID for linking related optimizations.
+ confidence_level: The threshold for computing confidence intervals. Defaults
+ to 0.9. Must be a number between 0 and 1.
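+
+ Example (a minimal construction sketch; names and dates are illustrative):
+
+ spec = OptimizationSpec(
+ optimization_name='q1_roi_optimization',
+ grid_name='q1_grid',
+ start_date=datetime.date(2024, 1, 1),
+ end_date=datetime.date(2024, 4, 1),
+ )
+ # `validate()` runs automatically via `Spec.__post_init__`.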
+ """
+
+ optimization_name: str
+ grid_name: str
+ group_id: str | None = None
+ confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL
+
+ @override
+ def validate(self):
+ """Checks that optimization parameters are valid."""
+ super().validate()
+
+ if not self.optimization_name or self.optimization_name.isspace():
+ raise ValueError('Optimization name must not be empty or blank.')
+
+ if not self.grid_name or self.grid_name.isspace():
+ raise ValueError('Grid name must not be empty or blank.')
+
+ if self.confidence_level < 0 or self.confidence_level > 1:
+ raise ValueError('Confidence level must be between 0 and 1.')
+
+
+S = TypeVar('S', bound=Spec)
+M = TypeVar('M', bound=message.Message)
+
+
+class ModelProcessor(abc.ABC, Generic[S, M]):
+ """Performs model-based analysis or optimization."""
+
+ @classmethod
+ @abc.abstractmethod
+ def spec_type(cls) -> type[S]:
+ """Returns the concrete Spec type that this ModelProcessor operates on."""
+ raise NotImplementedError()
+
+ @classmethod
+ @abc.abstractmethod
+ def output_type(cls) -> type[M]:
+ """Returns the concrete output type that this ModelProcessor produces."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def execute(self, specs: Sequence[S]) -> M:
+ """Runs an analysis/optimization on the model using the given specs.
+
+ Args:
+ specs: Sequence of Specs containing parameters needed for the
+ analysis/optimization. The specs must all be of the same type as
+ `self.spec_type()` for this processor.
+
+ Returns:
+ A proto containing the results of the analysis/optimization.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _set_output(self, output: pb.Mmm, result: M):
+ """Sets the output field in the given `Mmm` proto.
+
+ A model consumer that orchestrates this processor will indirectly call this
+ method (via `__call__`) to attach the output of `execute()` (a
+ processor-defined message `M`) to a partially built `Mmm` proto that the
+ model consumer manages.
+
+ Args:
+ output: The container output proto to which the given result message
+ should be attached.
+ result: An output of `execute()`.
+ """
+ raise NotImplementedError()
+
+ def __call__(self, specs: Sequence[S], output: pb.Mmm):
+ """Runs an analysis/optimization on the model using the given specs.
+
+ This also sets the appropriate output field in the given `Mmm` proto.
+
+ Args:
+ specs: Sequence of Specs containing parameters needed for the
+ analysis/optimization. The specs must all be of the same type as
+ `self.spec_type()` for this processor.
+ output: The output proto to which the results of the analysis/optimization
+ should be attached.
+
+ Raises:
+ ValueError: If any spec is not of the same type as `self.spec_type()`.
+ """
+ if not all(isinstance(spec, self.spec_type()) for spec in specs):
+ raise ValueError('Not all specs are of type %s' % self.spec_type())
+ self._set_output(output, self.execute(specs))
diff --git a/schema/processors/reach_frequency_optimization_processor.py b/schema/processors/reach_frequency_optimization_processor.py
new file mode 100644
index 000000000..cb9eb15d6
--- /dev/null
+++ b/schema/processors/reach_frequency_optimization_processor.py
@@ -0,0 +1,584 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a processor for reach and frequency optimization inference on a Meridian model. + +This module provides the `ReachFrequencyOptimizationProcessor`, which optimizes +the average frequency for reach and frequency (R&F) media channels in a trained +Meridian model to maximize ROI. + +The processor takes a trained model and a `ReachFrequencyOptimizationSpec` +object. The spec defines the constraints for the optimization, such as the +minimum and maximum average frequency to consider for each channel. + +Key Features: + +- Optimizes average frequency for all R&F channels simultaneously. +- Allows setting minimum and maximum frequency constraints. +- Generates detailed results, including the optimal average frequency for + each channel, the expected outcomes at this optimal frequency, and + response curves showing KPI/Revenue as a function of spend. +- Outputs results in a structured protobuf format + (`ReachFrequencyOptimization`). + +Key Classes: + +- `ReachFrequencyOptimizationSpec`: Dataclass to specify optimization + parameters and constraints. +- `ReachFrequencyOptimizationProcessor`: The main processor class to execute + the R&F optimization. + +Example Usage: + +```python +from schema.processors import reach_frequency_optimization_processor +from schema.processors import common +from schema.processors import model_processor +import datetime + +# Assuming 'mmm' is a trained Meridian model object with R&F channels +trained_model = model_processor.TrainedModel(mmm) + +spec = reach_frequency_optimization_processor.ReachFrequencyOptimizationSpec( + optimization_name="rf_optimize_q1", + start_date=datetime.date(2023, 1, 1), + end_date=datetime.date(2023, 4, 1), + min_frequency=1.0, + max_frequency=10.0, # Optional, defaults to model's max frequency + kpi_type=common.KpiType.REVENUE, +) + +processor = ( + reach_frequency_optimization_processor.ReachFrequencyOptimizationProcessor( + trained_model + ) +) +# result is a rf_pb.ReachFrequencyOptimization proto +result = processor.execute([spec]) + +print(f"R&F Optimization results for {spec.optimization_name}:") +# Access results from the proto, e.g.: +# result.results[0].optimized_channel_frequencies +# result.results[0].optimized_marketing_analysis +# result.results[0].frequency_outcome_grid +``` + +Note: You can provide the processor with multiple specs. This would result in +a `ReachFrequencyOptimization` output with multiple results therein. 
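+
+For example, a sketch of a multi-spec call (reusing `processor` from the
+example above; `spec_q1` and `spec_q2` are hypothetical specs, each with a
+distinct `group_id` if one is set):
+
+```python
+result = processor.execute([spec_q1, spec_q2])
+# `result.results` holds one `ReachFrequencyOptimizationResult` per spec.
+```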
+"""
+
+from collections.abc import Sequence
+import dataclasses
+
+from meridian import backend
+from meridian import constants
+from mmm.v1 import mmm_pb2 as pb
+from mmm.v1.common import kpi_type_pb2 as kpi_type_pb
+from mmm.v1.marketing import marketing_data_pb2
+from mmm.v1.marketing.analysis import marketing_analysis_pb2 as analysis_pb
+from mmm.v1.marketing.analysis import media_analysis_pb2 as media_analysis_pb
+from mmm.v1.marketing.analysis import outcome_pb2 as outcome_pb
+from mmm.v1.marketing.analysis import response_curve_pb2
+from mmm.v1.marketing.optimization import constraints_pb2 as constraints_pb
+from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
+from schema.processors import common
+from schema.processors import model_processor
+from schema.utils import time_record
+import numpy as np
+import xarray as xr
+
+
+__all__ = [
+ "ReachFrequencyOptimizationSpec",
+ "ReachFrequencyOptimizationProcessor",
+]
+
+
+_STEP_SIZE_DECIMAL_PRECISION = 1
+# The step size is one unit in the last rounded decimal place (i.e., 0.1).
+_STEP_SIZE = 10**-_STEP_SIZE_DECIMAL_PRECISION
+_TOL = 1e-6
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class ReachFrequencyOptimizationSpec(model_processor.OptimizationSpec):
+ """Spec dataclass for the marketing reach and frequency optimization processor.
+
+ A frequency grid is generated using the range `[rounded_min_frequency,
+ rounded_max_frequency]` and a step size of `STEP_SIZE=0.1`.
+ `rounded_min_frequency` and `rounded_max_frequency` are rounded to the
+ nearest multiple of `STEP_SIZE`.
+
+ This spec is used both as user input that informs the R&F optimization
+ processor of its constraints and parameters, and as an output structure that
+ is serializable to a `ReachFrequencyOptimizationSpec` proto. The latter serves
+ as metadata embedded in a `ReachFrequencyOptimizationResult`. The output spec
+ in the proto reflects the actual numbers used to generate the reach and
+ frequency optimization result.
+
+ Attributes:
+ min_frequency: The minimum frequency constraint for each channel. Must be
+ non-negative. Defaults to `1.0`.
+ max_frequency: The maximum frequency constraint for each channel. Must be
+ greater than min_frequency. Defaults to None. If this value is set to
+ None, the model's max frequency will be used.
+ rf_channels: The R&F media channels in the model. When resolved with a
+ model, the model's R&F channels will be present here. Ignored when used as
+ input.
+ kpi_type: A `common.KpiType` enum denoting whether the optimized KPI is of a
+ `'revenue'` or `'non-revenue'` type.
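+
+ For example (illustrative values): `min_frequency=1.26` and
+ `max_frequency=4.04` are rounded to `1.3` and `4.0` respectively, yielding
+ the grid `[1.3, 1.4, ..., 4.0]`.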
+ """ + + min_frequency: float = 1.0 + max_frequency: float | None = None + rf_channels: Sequence[str] = dataclasses.field(default_factory=list) + kpi_type: common.KpiType = common.KpiType.REVENUE + + @property + def selected_times(self) -> tuple[str | None, str | None] | None: + """The start and end dates, as a tuple of date strings.""" + start, end = (None, None) + if self.start_date is not None: + start = self.start_date.strftime(constants.DATE_FORMAT) + if self.end_date is not None: + end = self.end_date.strftime(constants.DATE_FORMAT) + + if start or end: + return (start, end) + return None + + @property + def objective(self) -> common.TargetMetric: + """A Meridian budget optimization objective is always ROI.""" + return common.TargetMetric.ROI + + def validate(self): + super().validate() + if self.min_frequency < 0: + raise ValueError("Min frequency must be non-negative.") + if ( + self.max_frequency is not None + and self.max_frequency < self.min_frequency + ): + raise ValueError("Max frequency must be greater than min frequency.") + + def to_proto(self) -> rf_pb.ReachFrequencyOptimizationSpec: + # When invoked as an output proto, the spec should have been fully resolved. + if self.start_date is None or self.end_date is None: + raise ValueError( + "Start and end dates must be resolved before this spec can be" + " serialized." + ) + + return rf_pb.ReachFrequencyOptimizationSpec( + date_interval=time_record.create_date_interval_pb( + self.start_date, self.end_date, tag=self.date_interval_tag + ), + rf_channel_constraints=[ + rf_pb.RfChannelConstraint( + channel_name=channel, + frequency_constraint=constraints_pb.FrequencyConstraint( + min_frequency=self.min_frequency, + max_frequency=self.max_frequency, + ), + ) + for channel in self.rf_channels + ], + objective=self.objective.value, + kpi_type=( + kpi_type_pb.KpiType.REVENUE + if self.kpi_type == common.KpiType.REVENUE + else kpi_type_pb.KpiType.NON_REVENUE + ), + ) + + +class ReachFrequencyOptimizationProcessor( + model_processor.ModelProcessor[ + ReachFrequencyOptimizationSpec, rf_pb.ReachFrequencyOptimization + ], +): + """A Processor for marketing reach and frequency optimization.""" + + def __init__( + self, + trained_model: model_processor.ModelType, + ): + trained_model = model_processor.ensure_trained_model(trained_model) + self._internal_analyzer = trained_model.internal_analyzer + self._meridian = trained_model.mmm + + if trained_model.mmm.input_data.rf_channel is None: + raise ValueError("RF channels must be set in the model.") + + self._all_rf_channels = trained_model.mmm.input_data.rf_channel.data + + @classmethod + def spec_type(cls) -> type[ReachFrequencyOptimizationSpec]: + return ReachFrequencyOptimizationSpec + + @classmethod + def output_type(cls) -> type[rf_pb.ReachFrequencyOptimization]: + return rf_pb.ReachFrequencyOptimization + + def _to_target_precision(self, value: float) -> float: + return round(value, _STEP_SIZE_DECIMAL_PRECISION) + + def _set_output( + self, output: pb.Mmm, result: rf_pb.ReachFrequencyOptimization + ): + output.marketing_optimization.reach_frequency_optimization.CopyFrom(result) + + def execute( + self, specs: Sequence[ReachFrequencyOptimizationSpec] + ) -> rf_pb.ReachFrequencyOptimization: + output = rf_pb.ReachFrequencyOptimization() + + group_ids = [spec.group_id for spec in specs if spec.group_id] + if len(set(group_ids)) != len(group_ids): + raise ValueError( + "Specified group_id must be unique among the given group of specs." 
+ ) + + for spec in specs: + selected_times = spec.resolver( + self._meridian.input_data.time_coordinates + ).resolve_to_enumerated_selected_times() + + grid_min_freq = self._to_target_precision(spec.min_frequency) + # If the max frequency is not set, use the model's max frequency. + grid_max_freq = self._to_target_precision( + spec.max_frequency or np.max(self._meridian.rf_tensors.frequency) + ) + grid = [ + self._to_target_precision(f) + for f in np.arange(grid_min_freq, grid_max_freq + _TOL, _STEP_SIZE) + ] + + # Note that the internal analyzer, like the budget optimizer, maximizes + # non-revenue KPI if input data is of non-revenue and the user selects + # `use_kpi=True`. Otherwise, it maximizes revenue KPI. + optimal_frequency = self._internal_analyzer.optimal_freq( + selected_times=selected_times, + confidence_level=spec.confidence_level, + freq_grid=grid, + use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE), + ) + response_curve = self._internal_analyzer.response_curves( + confidence_level=spec.confidence_level, + selected_times=selected_times, + by_reach=False, + use_kpi=(spec.kpi_type == common.KpiType.NON_REVENUE), + use_optimal_frequency=True, + ) + + spend_data = self._compute_spend_data(selected_times=selected_times) + + # Obtain the output spec. + start, end = spec.resolver( + self._meridian.input_data.time_coordinates + ).resolve_to_date_interval_open_end() + + # Copy the current spec, and resolve its date interval as well as model- + # dependent parameters. + output_spec = dataclasses.replace( + spec, + rf_channels=self._all_rf_channels, + min_frequency=grid_min_freq, + max_frequency=grid_max_freq, + start_date=start, + end_date=end, + ) + + output.results.append( + self._to_reach_frequency_optimization_result( + output_spec, + optimal_frequency, + response_curve, + spend_data, + ) + ) + return output + + def _compute_spend_data( + self, selected_times: list[str] | None = None + ) -> xr.Dataset: + aggregated_spends = self._internal_analyzer.get_historical_spend( + selected_times + ) + aggregated_rf_spend = aggregated_spends.sel( + {constants.CHANNEL: self._all_rf_channels} + ).data + total_spend = np.sum(aggregated_spends.data) + pct_of_spend = 100.0 * aggregated_rf_spend / total_spend + + xr_dims = (constants.RF_CHANNEL,) + xr_coords = { + constants.RF_CHANNEL: ( + [constants.RF_CHANNEL], + list(self._all_rf_channels), + ), + } + xr_data_vars = { + constants.SPEND: (xr_dims, aggregated_rf_spend), + constants.PCT_OF_SPEND: (xr_dims, pct_of_spend), + } + + return xr.Dataset( + data_vars=xr_data_vars, + coords=xr_coords, + ) + + def _to_reach_frequency_optimization_result( + self, + spec: ReachFrequencyOptimizationSpec, + optimal_frequency: xr.Dataset, + response_curve: xr.Dataset, + spend_data: xr.Dataset, + ) -> rf_pb.ReachFrequencyOptimizationResult: + """Converts given optimal frequency dataset to protobuf form.""" + result = rf_pb.ReachFrequencyOptimizationResult( + name=spec.optimization_name, + spec=spec.to_proto(), + optimized_channel_frequencies=_create_optimized_channel_frequencies( + optimal_frequency + ), + optimized_marketing_analysis=self._to_marketing_analysis( + spec, + optimal_frequency, + response_curve, + spend_data, + ), + frequency_outcome_grid=self._create_frequency_outcome_grid( + optimal_frequency, + spec, + ), + ) + if spec.group_id: + result.group_id = spec.group_id + return result + + def _to_marketing_analysis( + self, + spec: ReachFrequencyOptimizationSpec, + optimal_frequency: xr.Dataset, + response_curve: xr.Dataset, + spend_data: xr.Dataset, 
+ ) -> analysis_pb.MarketingAnalysis: + """Converts an optimal frequency dataset to a `MarketingAnalysis` proto.""" + # `spec` should have been resolved with concrete date interval parameters. + assert spec.start_date is not None and spec.end_date is not None + + optimized_marketing_analysis = analysis_pb.MarketingAnalysis( + date_interval=time_record.create_date_interval_pb( + start_date=spec.start_date, + end_date=spec.end_date, + ), + ) + + # Create a per-channel MediaAnalysis. + channels = optimal_frequency.coords[constants.RF_CHANNEL].data + for channel in channels: + channel_optimal_frequency = optimal_frequency.sel(rf_channel=channel) + channel_spend_data = spend_data.sel(rf_channel=channel) + + # TODO Add non-media analyses. + channel_media_analysis = media_analysis_pb.MediaAnalysis( + channel_name=channel, + response_curve=_compute_response_curve( + response_curve, + channel, + ), + spend_info=media_analysis_pb.SpendInfo( + spend=channel_spend_data[constants.SPEND].data.item(), + spend_share=( + channel_spend_data[constants.PCT_OF_SPEND].data.item() + ), + ), + ) + + # Output one outcome per channel: either revenue or non-revenue. + channel_media_analysis.media_outcomes.append( + _to_outcome( + channel_optimal_frequency, + is_revenue_kpi=optimal_frequency.attrs[constants.IS_REVENUE_KPI], + ) + ) + + optimized_marketing_analysis.media_analyses.append(channel_media_analysis) + + return optimized_marketing_analysis + + def _create_frequency_outcome_grid( + self, + optimal_frequency_dataset: xr.Dataset, + spec: ReachFrequencyOptimizationSpec, + ) -> rf_pb.FrequencyOutcomeGrid: + """Creates a FrequencyOutcomeGrid proto.""" + channel_cells = [] + frequencies = optimal_frequency_dataset.coords[constants.FREQUENCY].data + channels = optimal_frequency_dataset.coords[constants.RF_CHANNEL].data + input_tensor_dims = "gtc" + output_tensor_dims = "c" + + for channel in channels: + cells = [] + for frequency in frequencies: + new_frequency = ( + backend.ones_like(self._meridian.rf_tensors.frequency) * frequency + ) + new_reach = ( + self._meridian.rf_tensors.frequency + * self._meridian.rf_tensors.reach + / new_frequency + ) + channel_mask = [c == channel for c in channels] + filtered_reach = backend.boolean_mask(new_reach, channel_mask, axis=2) + aggregated_reach = backend.einsum( + f"{input_tensor_dims}->...{output_tensor_dims}", filtered_reach + ) + reach = aggregated_reach.numpy()[-1] + + metric_data_array = optimal_frequency_dataset[constants.ROI].sel( + frequency=frequency, rf_channel=channel + ) + outcome = common.to_estimate(metric_data_array, spec.confidence_level) + + cell = rf_pb.FrequencyOutcomeGrid.Cell( + outcome=outcome, + reach_frequency=marketing_data_pb2.ReachFrequency( + reach=int(reach), + average_frequency=frequency, + ), + ) + cells.append(cell) + + channel_cell = rf_pb.FrequencyOutcomeGrid.ChannelCells( + channel_name=channel, + cells=cells, + ) + channel_cells.append(channel_cell) + + return rf_pb.FrequencyOutcomeGrid( + name=spec.grid_name, + frequency_step_size=_STEP_SIZE, + channel_cells=channel_cells, + ) + + +def _create_optimized_channel_frequencies( + optimal_frequency_dataset: xr.Dataset, +) -> list[rf_pb.OptimizedChannelFrequency]: + """Creates an OptimizedChannelFrequency proto for each channel in the dataset.""" + optimal_frequency_protos = [] + optimal_frequency = optimal_frequency_dataset[constants.OPTIMAL_FREQUENCY] + channels = optimal_frequency.coords[constants.RF_CHANNEL].data + + for channel in channels: + optimal_frequency_protos.append( + 
rf_pb.OptimizedChannelFrequency( + channel_name=channel, + optimal_average_frequency=optimal_frequency.sel( + rf_channel=channel + ).item(), + ) + ) + return optimal_frequency_protos + + +def _to_outcome( + channel_optimal_frequency: xr.Dataset, + is_revenue_kpi: bool, +) -> outcome_pb.Outcome: + """Returns an `Outcome` value for a given channel's optimized media analysis. + + Args: + channel_optimal_frequency: A channel-selected dataset from + `Analyzer.optimal_freq()`. + is_revenue_kpi: Whether the KPI is revenue-based. + """ + confidence_level = channel_optimal_frequency.attrs[constants.CONFIDENCE_LEVEL] + return outcome_pb.Outcome( + kpi_type=( + kpi_type_pb.REVENUE if is_revenue_kpi else kpi_type_pb.NON_REVENUE + ), + roi=common.to_estimate( + channel_optimal_frequency.optimized_roi, confidence_level + ), + marginal_roi=common.to_estimate( + channel_optimal_frequency.optimized_mroi_by_frequency, + confidence_level, + ), + cost_per_contribution=common.to_estimate( + channel_optimal_frequency.optimized_cpik, + confidence_level=confidence_level, + ), + contribution=outcome_pb.Contribution( + value=common.to_estimate( + channel_optimal_frequency.optimized_incremental_outcome, + confidence_level, + ), + ), + effectiveness=outcome_pb.Effectiveness( + media_unit=constants.IMPRESSIONS, + value=common.to_estimate( + channel_optimal_frequency.optimized_effectiveness, + confidence_level, + ), + ), + ) + + +def _compute_response_curve( + response_curve_dataset: xr.Dataset, + channel_name: str, +) -> response_curve_pb2.ResponseCurve: + """Returns a ResponseCurve proto for the given channel. + + Args: + response_curve_dataset: A dataset containing the data needed to generate a + response curve. + channel_name: The name of the channel to analyze. + """ + + spend_multiplier_list = response_curve_dataset.coords[ + constants.SPEND_MULTIPLIER + ].data + response_points: list[response_curve_pb2.ResponsePoint] = [] + + for spend_multiplier in spend_multiplier_list: + spend = ( + response_curve_dataset[constants.SPEND] + .sel(spend_multiplier=spend_multiplier, channel=channel_name) + .data.item() + ) + incremental_outcome = ( + response_curve_dataset[constants.INCREMENTAL_OUTCOME] + .sel( + spend_multiplier=spend_multiplier, + channel=channel_name, + metric=constants.MEAN, + ) + .data.item() + ) + + response_point = response_curve_pb2.ResponsePoint( + input_value=spend, + incremental_kpi=incremental_outcome, + ) + response_points.append(response_point) + + return response_curve_pb2.ResponseCurve( + input_name=constants.SPEND, + response_points=response_points, + ) diff --git a/schema/serde/__init__.py b/schema/serde/__init__.py new file mode 100644 index 000000000..54cd0065c --- /dev/null +++ b/schema/serde/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A serialization and deserialization library for Meridian models. + +For entry points API, see `meridian_serde` module docs. 
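+
+A typical entry point looks like the following sketch, mirroring how
+`ModelKernelProcessor` invokes the API (`mmm`, `model_id`, and `version` are
+placeholders):
+
+```python
+from schema.serde import meridian_serde
+
+serde = meridian_serde.MeridianSerde()
+kernel_proto = serde.serialize(
+ mmm, model_id, version, include_convergence_info=True
+)
+```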
+""" + +from schema.serde import constants +from schema.serde import distribution +from schema.serde import hyperparameters +from schema.serde import inference_data +from schema.serde import meridian_serde +from schema.serde import serde diff --git a/schema/serde/constants.py b/schema/serde/constants.py new file mode 100644 index 000000000..58426ace8 --- /dev/null +++ b/schema/serde/constants.py @@ -0,0 +1,48 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants shared across the Meridian serde library.""" + +# Constants for hyperparameters protobuf structure +BASELINE_GEO_ONEOF = 'baseline_geo_oneof' +BASELINE_GEO_INT = 'baseline_geo_int' +BASELINE_GEO_STRING = 'baseline_geo_string' +CONTROL_POPULATION_SCALING_ID = 'control_population_scaling_id' +HOLDOUT_ID = 'holdout_id' +NON_MEDIA_POPULATION_SCALING_ID = 'non_media_population_scaling_id' +ADSTOCK_DECAY_SPEC = 'adstock_decay_spec' +GLOBAL_ADSTOCK_DECAY = 'global_adstock_decay' +ADSTOCK_DECAY_BY_CHANNEL = 'adstock_decay_by_channel' +DEFAULT_DECAY = 'geometric' + +# Constants for marketing data protobuf structure +GEO_INFO = 'geo_info' +METADATA = 'metadata' +REACH_FREQUENCY = 'reach_frequency' + +# Constants for distribution protobuf structure +DISTRIBUTION_TYPE = 'distribution_type' +BATCH_BROADCAST_DISTRIBUTION = 'batch_broadcast' +DETERMINISTIC_DISTRIBUTION = 'deterministic' +HALF_NORMAL_DISTRIBUTION = 'half_normal' +LOG_NORMAL_DISTRIBUTION = 'log_normal' +NORMAL_DISTRIBUTION = 'normal' +TRANSFORMED_DISTRIBUTION = 'transformed' +TRUNCATED_NORMAL_DISTRIBUTION = 'truncated_normal' +UNIFORM_DISTRIBUTION = 'uniform' +BETA_DISTRIBUTION = 'beta' +BIJECTOR_TYPE = 'bijector_type' +SHIFT_BIJECTOR = 'shift' +SCALE_BIJECTOR = 'scale' +RECIPROCAL_BIJECTOR = 'reciprocal' diff --git a/schema/serde/distribution.py b/schema/serde/distribution.py new file mode 100644 index 000000000..5735b5056 --- /dev/null +++ b/schema/serde/distribution.py @@ -0,0 +1,548 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Serialization and deserialization of `Distribution` objects for priors."""
+
+from __future__ import annotations
+
+import hashlib
+import inspect
+import types
+from typing import Any, Callable, Sequence, TypeVar
+import warnings
+
+from meridian import backend
+from meridian import constants
+from meridian.model import prior_distribution as pd
+from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
+from schema.serde import constants as sc
+from schema.serde import serde
+
+from tensorflow.core.framework import tensor_shape_pb2 # pylint: disable=g-direct-tensorflow-import
+
+FunctionRegistry = dict[str, Callable[..., Any]]
+
+
+MeridianPriorDistributions = (
+ meridian_pb.PriorTfpDistributions
+)
+
+
+# TODO: b/436637084 - Delete enumerated schema.
+class DistributionSerde(
+ serde.Serde[MeridianPriorDistributions, pd.PriorDistribution]
+):
+ """Serializes a Meridian prior distributions container to a `PriorTfpDistributions` proto, and back."""
+
+ def __init__(self, function_registry: FunctionRegistry | None = None):
+ """Initializes a `DistributionSerde` instance.
+
+ Args:
+ function_registry: A lookup table containing custom functions used by
+ various `backend.tfd` classes.
+ """
+ self._function_registry = function_registry
+
+ @property
+ def function_registry(self) -> FunctionRegistry | None:
+ return self._function_registry
+
+ def serialize(
+ self, obj: pd.PriorDistribution
+ ) -> meridian_pb.PriorTfpDistributions:
+ """Serializes the given Meridian priors container into a `MeridianPriorDistributions` proto."""
+ proto = meridian_pb.PriorTfpDistributions()
+ for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS:
+ if not hasattr(obj, param):
+ continue
+ # TODO: b/436636530 - Implement serialization for generic schema.
+ getattr(proto, param).CopyFrom(
+ self._to_distribution_proto(getattr(obj, param))
+ )
+ if self.function_registry is not None:
+ hashed_function_registry = hash_function_registry(self.function_registry)
+ proto.function_registry.update(hashed_function_registry)
+ return proto
+
+ def deserialize(
+ self,
+ serialized: MeridianPriorDistributions,
+ serialized_version: str = "",
+ force_deserialization: bool = False,
+ ) -> pd.PriorDistribution:
+ """Deserializes the `PriorTfpDistributions` proto.
+
+ WARNING: If any custom functions in the function registry are modified after
+ serialization, the deserialized model can differ from the original model, as
+ the original function's behavior is no longer guaranteed. This will result
+ in an error during deserialization.
+
+ If you are intentionally changing functions and are confident that the
+ changes will not affect the deserialized model, you can bypass the safety
+ mechanism to force deserialization. See the example at the end of this
+ docstring.
+
+ Args:
+ serialized: A serialized `PriorDistributions` object.
+ serialized_version: The version of the serialized Meridian model.
+ force_deserialization: If True, bypasses the safety check that validates
+ whether functions within `function_registry` have changed after
+ serialization. Use with caution. This should only be used if you have
+ intentionally modified a custom function and are confident that the
+ changes will not affect the deserialized model. A safer alternative is
+ to first deserialize the model with the original functions and then
+ serialize it with the new ones.
+
+ Returns:
+ A deserialized `PriorDistribution` object.
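+
+ Example (a sketch; `serialized_priors` and `registry` are assumed to come
+ from an earlier `serialize()` call):
+
+ serde = DistributionSerde(function_registry=registry)
+ # Fails if a registered function changed since serialization:
+ priors = serde.deserialize(serialized_priors)
+ # Bypass the registry check (use with caution):
+ priors = serde.deserialize(serialized_priors, force_deserialization=True)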
+ """ + kwargs = {} + for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS: + if not hasattr(serialized, param): + continue + # A parameter may be unspecified in a serialized proto message because: + # (1) It is left unset for Meridian to set its default value. + # (2) The message was created from a previous Meridian version after + # introducing a new parameter. + if not serialized.HasField(param): + continue + param_name = getattr(serialized, param) + if isinstance(serialized, meridian_pb.PriorTfpDistributions): + if force_deserialization: + warnings.warn( + "You're attempting to deserialize a model while ignoring changes" + " to custom functions. This is a risky operation that can" + " potentially lead to a deserialized model that behaves" + " differently from the original, resulting in unexpected behavior" + " or model failure. We strongly recommend a safer two-step" + " process: deserialize the model using the original function" + " registry and reserialize the model using the updated registry." + " Please proceed with caution." + ) + else: + _validate_function_registry(self.function_registry, serialized) + kwargs[param] = self._from_distribution_proto(param_name) + # copybara: strip_begin(legacy proto) + elif isinstance(serialized, meridian_pb.PriorDistributions): + kwargs[param] = _from_legacy_distribution_proto(param_name) + # copybara: strip_end + return pd.PriorDistribution(**kwargs) + + def _to_distribution_proto( + self, + dist: backend.tfd.Distribution, + ) -> meridian_pb.TfpDistribution: + """Converts a TensorFlow `Distribution` object to a `TfpDistribution` proto.""" + dist_name = type(dist).__name__ + dist_class = getattr(backend.tfd, dist_name) + return meridian_pb.TfpDistribution( + distribution_type=dist_name, + parameters={ + name: self._to_parameter_value_proto(name, value, dist_class) + for name, value in dist.parameters.items() + }, + ) + + def _to_bijector_proto( + self, + bijector: backend.bijectors.Bijector, + ) -> meridian_pb.TfpBijector: + """Converts a TensorFlow `Bijector` object to a `TfpBijector` proto.""" + bij_name = type(bijector).__name__ + bij_class = getattr(backend.bijectors, bij_name) + return meridian_pb.TfpBijector( + bijector_type=bij_name, + parameters={ + name: self._to_parameter_value_proto(name, value, bij_class) + for name, value in bijector.parameters.items() + }, + ) + + def _to_parameter_value_proto( + self, + param_name: str, + value: Any, + dist: backend.tfd.Distribution | backend.bijectors.Bijector, + ) -> meridian_pb.TfpParameterValue: + """Converts a TensorFlow `Distribution` parameter value to a `TfpParameterValue` proto.""" + # Handle built-in types. 
+  def _to_parameter_value_proto(
+      self,
+      param_name: str,
+      value: Any,
+      dist: type[backend.tfd.Distribution] | type[backend.bijectors.Bijector],
+  ) -> meridian_pb.TfpParameterValue:
+    """Converts a TFP `Distribution` or `Bijector` parameter value to a `TfpParameterValue` proto."""
+    # Handle built-in types. `bool` must be matched before `int`: `bool` is a
+    # subclass of `int`, so a `case int()` placed first would also capture
+    # True/False and they would never reach `bool_value`.
+    match value:
+      case bool():
+        return meridian_pb.TfpParameterValue(bool_value=value)
+      case float():
+        return meridian_pb.TfpParameterValue(scalar_value=value)
+      case int():
+        return meridian_pb.TfpParameterValue(int_value=value)
+      case str():
+        return meridian_pb.TfpParameterValue(string_value=value)
+      case None:
+        return meridian_pb.TfpParameterValue(none_value=True)
+      case list():
+        value_generator = (
+            self._to_parameter_value_proto(param_name, v, dist) for v in value
+        )
+        return meridian_pb.TfpParameterValue(
+            list_value=meridian_pb.TfpParameterValue.List(
+                values=value_generator
+            )
+        )
+      case dict():
+        dict_value = {
+            k: self._to_parameter_value_proto(param_name, v, dist)
+            for k, v in value.items()
+        }
+        return meridian_pb.TfpParameterValue(
+            dict_value=meridian_pb.TfpParameterValue.Dict(value_map=dict_value)
+        )
+      case backend.Tensor():
+        return meridian_pb.TfpParameterValue(
+            tensor_value=backend.make_tensor_proto(value)
+        )
+      case backend.tfd.Distribution():
+        return meridian_pb.TfpParameterValue(
+            distribution_value=self._to_distribution_proto(value)
+        )
+      case backend.bijectors.Bijector():
+        return meridian_pb.TfpParameterValue(
+            bijector_value=self._to_bijector_proto(value)
+        )
+      case backend.tfd.ReparameterizationType():
+        fully_reparameterized = value == backend.tfd.FULLY_REPARAMETERIZED
+        return meridian_pb.TfpParameterValue(
+            fully_reparameterized=fully_reparameterized
+        )
+      case types.FunctionType():
+        # Check whether the value is the parameter's declared default.
+        # (Functions are always truthy, so an identity check suffices.)
+        signature = inspect.signature(dist.__init__)
+        param = signature.parameters[param_name]
+        if param.default is value:
+          return meridian_pb.TfpParameterValue(
+              function_param=meridian_pb.TfpParameterValue.FunctionParam(
+                  uses_default=True
+              )
+          )
+
+        # Check against registry.
+        registry = self.function_registry
+        if registry is not None:
+          for function_key, func in registry.items():
+            if func is value:
+              return meridian_pb.TfpParameterValue(
+                  function_param=meridian_pb.TfpParameterValue.FunctionParam(
+                      function_key=function_key
+                  )
+              )
+
+        raise ValueError(
+            f"Custom function `{param_name}` detected for"
+            f" {dist.__name__}, but not found in registry. Please"
+            " add custom functions to registry when saving models."
+        )
+
+    # Handle unsupported types.
+    raise TypeError(f"Unsupported type: {type(value)}, {value}")
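A two-line illustration of the ordering constraint noted in the comment above; this is standard Python `match` semantics, not anything Meridian-specific:

```python
for v in (True, 1):
  match v:
    case bool():
      print(v, "-> bool_value")  # True only lands here if bool comes first
    case int():
      print(v, "-> int_value")
```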
+  def _from_distribution_proto(
+      self,
+      dist_proto: meridian_pb.TfpDistribution,
+  ) -> backend.tfd.Distribution:
+    """Converts a `TfpDistribution` proto to a TFP `Distribution` object."""
+    dist_class_name = dist_proto.distribution_type
+    dist_class = getattr(backend.tfd, dist_class_name)
+    dist_parameters = dist_proto.parameters
+    input_parameters = {
+        k: self._unpack_tfp_parameters(k, v, dist_class)
+        for k, v in dist_parameters.items()
+    }
+    return dist_class(**input_parameters)
+
+  def _from_bijector_proto(
+      self,
+      bij_proto: meridian_pb.TfpBijector,
+  ) -> backend.bijectors.Bijector:
+    """Converts a `TfpBijector` proto to a TFP `Bijector` object."""
+    bij_class_name = bij_proto.bijector_type
+    bij_class = getattr(backend.bijectors, bij_class_name)
+    bij_parameters = bij_proto.parameters
+    input_parameters = {
+        name: self._unpack_tfp_parameters(name, value, bij_class)
+        for name, value in bij_parameters.items()
+    }
+
+    return bij_class(**input_parameters)
+
+  def _unpack_tfp_parameters(
+      self,
+      param_name: str,
+      param_value: meridian_pb.TfpParameterValue,
+      dist_class: (
+          type[backend.tfd.Distribution] | type[backend.bijectors.Bijector]
+      ),
+  ) -> Any:
+    """Unpacks a `TfpParameterValue` proto into a Python value."""
+    match param_value.WhichOneof("value_type"):
+      # Handle built-in types.
+      case "scalar_value":
+        return param_value.scalar_value
+      case "int_value":
+        return param_value.int_value
+      case "bool_value":
+        return param_value.bool_value
+      case "string_value":
+        return param_value.string_value
+      case "none_value":
+        return None
+      case "list_value":
+        return [
+            self._unpack_tfp_parameters(param_name, v, dist_class)
+            for v in param_value.list_value.values
+        ]
+      case "dict_value":
+        items = param_value.dict_value.value_map.items()
+        return {
+            key: self._unpack_tfp_parameters(key, value, dist_class)
+            for key, value in items
+        }
+
+      # Handle custom types.
+      case "tensor_value":
+        return backend.to_tensor(backend.make_ndarray(param_value.tensor_value))
+      case "distribution_value":
+        return self._from_distribution_proto(param_value.distribution_value)
+      case "bijector_value":
+        return self._from_bijector_proto(param_value.bijector_value)
+      case "fully_reparameterized":
+        if param_value.fully_reparameterized:
+          return backend.tfd.FULLY_REPARAMETERIZED
+        else:
+          return backend.tfd.NOT_FULLY_REPARAMETERIZED
+
+      # Handle functions.
+      case "function_param":
+        function_param = param_value.function_param
+        # Check against registry.
+        if function_param.HasField("function_key"):
+          registry = self.function_registry
+          if registry is not None and function_param.function_key in registry:
+            return registry.get(function_param.function_key)
+        # Check for default value.
+        if (
+            function_param.HasField("uses_default")
+            and function_param.uses_default
+        ):
+          signature = inspect.signature(dist_class.__init__)
+          return signature.parameters[param_name].default
+        raise ValueError(f"No function found for {param_name}")
+
+      # Handle unsupported value kinds.
+      case _:
+        raise ValueError(
+            "Unsupported TFP parameter value type:"
+            f" {param_value.WhichOneof('value_type')!r}"
+        )
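The registry validation that follows compares SHA-256 digests of each function's *source text*, so even a cosmetic edit (whitespace, a comment) invalidates the stored hash. A small illustration of the equivalence with `get_hash(inspect.getsource(fn))`; `my_fn` is hypothetical:

```python
import hashlib
import inspect


def my_fn(x):  # stand-in for a registered custom function
  return x + 1


digest = hashlib.sha256(inspect.getsource(my_fn).encode("utf-8")).hexdigest()
# hash_function_registry({"my_fn": my_fn}) stores {"my_fn": digest}; any
# textual change to my_fn produces a different digest and therefore a
# hash-mismatch error on deserialize().
print(digest[:12])
```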
+def _validate_function_registry(
+    current_function_registry: FunctionRegistry | None,
+    serial: meridian_pb.PriorTfpDistributions,
+):
+  """Validates that the registry's functions are unchanged since serialization."""
+  stored_function_registry = serial.function_registry
+  if stored_function_registry and not current_function_registry:
+    raise ValueError(
+        "A function registry was detected on the saved model but not provided"
+        " when attempting to load."
+    )
+  elif not stored_function_registry and current_function_registry:
+    warnings.warn(
+        "A function registry was provided when attempting to load the model,"
+        " but none was found on the serialized model. Custom functions will be"
+        " ignored."
+    )
+  elif stored_function_registry and current_function_registry:
+    for key, value in current_function_registry.items():
+      if key not in stored_function_registry:
+        raise ValueError(
+            f"Function registry key '{key}' was not present when the model was"
+            " serialized."
+        )
+      stored_hash = stored_function_registry[key]
+      evaluated_hash = get_hash(inspect.getsource(value))
+      if stored_hash != evaluated_hash:
+        raise ValueError(f"Function registry hash mismatch for {key}.")
+
+
+def hash_function_registry(registry: FunctionRegistry) -> dict[str, str]:
+  """Maps each registry key to a hash of its function's source code."""
+  return {
+      key: get_hash(inspect.getsource(function))
+      for key, function in registry.items()
+  }
+
+
+def get_hash(value: str) -> str:
+  """Returns the SHA-256 hex digest of `value`."""
+  return hashlib.sha256(value.encode("utf-8")).hexdigest()
+
+
+# copybara: strip_begin
+
+
+def _from_legacy_bijector_proto(
+    bijector_proto: meridian_pb.Distribution.Bijector,
+) -> backend.bijectors.Bijector:
+  """Converts a `Bijector` proto to a `Bijector` object."""
+  bijector_type_field = bijector_proto.WhichOneof(sc.BIJECTOR_TYPE)
+  match bijector_type_field:
+    case sc.SHIFT_BIJECTOR:
+      return backend.bijectors.Shift(
+          shift=_deserialize_sequence(bijector_proto.shift.shifts)
+      )
+    case sc.SCALE_BIJECTOR:
+      return backend.bijectors.Scale(
+          scale=_deserialize_sequence(bijector_proto.scale.scales),
+          log_scale=_deserialize_sequence(bijector_proto.scale.log_scales),
+      )
+    case sc.RECIPROCAL_BIJECTOR:
+      return backend.bijectors.Reciprocal()
+    case _:
+      raise ValueError(
+          f"Unsupported Bijector proto type: {bijector_type_field};"
+          f" Bijector proto:\n{bijector_proto}"
+      )
+
+
+def _from_legacy_distribution_proto(
+    dist_proto: meridian_pb.Distribution,
+) -> backend.tfd.Distribution:
+  """Converts a `Distribution` proto to a `Distribution` object."""
+  dist_type_field = dist_proto.WhichOneof(sc.DISTRIBUTION_TYPE)
+  match dist_type_field:
+    case sc.BATCH_BROADCAST_DISTRIBUTION:
+      return backend.tfd.BatchBroadcast(
+          name=dist_proto.name,
+          distribution=_from_legacy_distribution_proto(
+              dist_proto.batch_broadcast.distribution
+          ),
+          with_shape=_from_shape_proto(dist_proto.batch_broadcast.batch_shape),
+      )
+    case sc.TRANSFORMED_DISTRIBUTION:
+      return backend.tfd.TransformedDistribution(
+          name=dist_proto.name,
+          distribution=_from_legacy_distribution_proto(
+              dist_proto.transformed.distribution
+          ),
+          bijector=_from_legacy_bijector_proto(dist_proto.transformed.bijector),
+      )
+    case sc.DETERMINISTIC_DISTRIBUTION:
+      return backend.tfd.Deterministic(
+          name=dist_proto.name,
+          loc=_deserialize_sequence(dist_proto.deterministic.locs),
+      )
+    case sc.HALF_NORMAL_DISTRIBUTION:
+      return backend.tfd.HalfNormal(
+          name=dist_proto.name,
+
scale=_deserialize_sequence(dist_proto.half_normal.scales), + ) + case sc.LOG_NORMAL_DISTRIBUTION: + return backend.tfd.LogNormal( + name=dist_proto.name, + loc=_deserialize_sequence(dist_proto.log_normal.locs), + scale=_deserialize_sequence(dist_proto.log_normal.scales), + ) + case sc.NORMAL_DISTRIBUTION: + return backend.tfd.Normal( + name=dist_proto.name, + loc=_deserialize_sequence(dist_proto.normal.locs), + scale=_deserialize_sequence(dist_proto.normal.scales), + ) + case sc.TRUNCATED_NORMAL_DISTRIBUTION: + if ( + hasattr(dist_proto.truncated_normal, "lows") + and dist_proto.truncated_normal.lows + ): + if dist_proto.truncated_normal.low: + _show_warning("low", "TruncatedNormal") + low = _deserialize_sequence(dist_proto.truncated_normal.lows) + else: + low = dist_proto.truncated_normal.low + + if ( + hasattr(dist_proto.truncated_normal, "highs") + and dist_proto.truncated_normal.highs + ): + if dist_proto.truncated_normal.high: + _show_warning("high", "TruncatedNormal") + high = _deserialize_sequence(dist_proto.truncated_normal.highs) + else: + high = dist_proto.truncated_normal.high + return backend.tfd.TruncatedNormal( + name=dist_proto.name, + loc=_deserialize_sequence(dist_proto.truncated_normal.locs), + scale=_deserialize_sequence(dist_proto.truncated_normal.scales), + low=low, + high=high, + ) + case sc.UNIFORM_DISTRIBUTION: + if hasattr(dist_proto.uniform, "lows") and dist_proto.uniform.lows: + if dist_proto.uniform.low: + _show_warning("low", "Uniform") + low = _deserialize_sequence(dist_proto.uniform.lows) + else: + low = dist_proto.uniform.low + + if hasattr(dist_proto.uniform, "highs") and dist_proto.uniform.highs: + if dist_proto.uniform.high: + _show_warning("high", "Uniform") + high = _deserialize_sequence(dist_proto.uniform.highs) + else: + high = dist_proto.uniform.high + + return backend.tfd.Uniform( + name=dist_proto.name, + low=low, + high=high, + ) + case sc.BETA_DISTRIBUTION: + return backend.tfd.Beta( + name=dist_proto.name, + concentration1=_deserialize_sequence(dist_proto.beta.alpha), + concentration0=_deserialize_sequence(dist_proto.beta.beta), + ) + case _: + raise ValueError( + f"Unsupported Distribution proto type: {dist_type_field};" + f" Distribution proto:\n{dist_proto}" + ) + + +def _show_warning(field_name: str, dist_name: str) -> None: + warnings.warn( + f"Both `{field_name}s` and `{field_name}` are specified in" + f" {dist_name} distribution proto. Prioritizing `{field_name}s` since" + f" `{field_name}` is deprecated.", + DeprecationWarning, + ) + + +def _from_shape_proto( + shape_proto: tensor_shape_pb2.TensorShapeProto, +) -> backend.TensorShapeInstance: + """Converts a `TensorShapeProto` to a `TensorShape`.""" + return backend.TensorShape([dim.size for dim in shape_proto.dim]) + + +T = TypeVar("T") + + +def _deserialize_sequence(args: Sequence[T]) -> T | Sequence[T] | None: + if not args: + return None + return args[0] if len(args) == 1 else list(args) + +# copybara: strip_end diff --git a/schema/serde/hyperparameters.py b/schema/serde/hyperparameters.py new file mode 100644 index 000000000..aca543003 --- /dev/null +++ b/schema/serde/hyperparameters.py @@ -0,0 +1,310 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Serde for Hyperparameters."""
+
+import warnings
+
+from meridian import backend
+from meridian import constants as c
+from meridian.model import spec
+from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
+from schema.serde import constants as sc
+from schema.serde import serde
+import numpy as np
+
+_MediaEffectsDist = meridian_pb.MediaEffectsDistribution
+_PaidMediaPriorType = meridian_pb.PaidMediaPriorType
+_NonPaidTreatmentsPriorType = meridian_pb.NonPaidTreatmentsPriorType
+
+
+def _media_effects_dist_to_proto_enum(
+    media_effects_dist: str,
+) -> _MediaEffectsDist:
+  """Converts a media effects distribution string to its proto enum."""
+  match media_effects_dist:
+    case c.MEDIA_EFFECTS_LOG_NORMAL:
+      return _MediaEffectsDist.LOG_NORMAL
+    case c.MEDIA_EFFECTS_NORMAL:
+      return _MediaEffectsDist.NORMAL
+    case _:
+      return _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED
+
+
+def _proto_enum_to_media_effects_dist(
+    proto_enum: _MediaEffectsDist,
+) -> str:
+  """Converts a `_MediaEffectsDist` enum to its string representation."""
+  match proto_enum:
+    case _MediaEffectsDist.LOG_NORMAL:
+      return c.MEDIA_EFFECTS_LOG_NORMAL
+    case _MediaEffectsDist.NORMAL:
+      return c.MEDIA_EFFECTS_NORMAL
+    case _:
+      raise ValueError(
+          "Unsupported MediaEffectsDistribution proto enum value:"
+          f" {proto_enum}."
+      )
+
+
+def _paid_media_prior_type_to_proto_enum(
+    paid_media_prior_type: str | None,
+) -> _PaidMediaPriorType:
+  """Converts a paid media prior type string to its proto enum."""
+  if paid_media_prior_type is None:
+    return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
+  try:
+    return _PaidMediaPriorType.Value(paid_media_prior_type.upper())
+  except ValueError:
+    warnings.warn(
+        f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
+        " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
+    )
+    return _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
+
+
+def _proto_enum_to_paid_media_prior_type(
+    proto_enum: _PaidMediaPriorType,
+) -> str | None:
+  """Converts a `_PaidMediaPriorType` enum to its string representation."""
+  if proto_enum == _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED:
+    return None
+  return _PaidMediaPriorType.Name(proto_enum).lower()
+
+
+def _non_paid_prior_type_to_proto_enum(
+    non_paid_prior_type: str,
+) -> _NonPaidTreatmentsPriorType:
+  """Converts a non-paid prior type string to its proto enum."""
+  try:
+    return _NonPaidTreatmentsPriorType.Value(
+        f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
+    )
+  except ValueError:
+    warnings.warn(
+        f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
+        " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
+    )
+    return (
+        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
+    )
+
+
+def _proto_enum_to_non_paid_prior_type(
+    proto_enum: _NonPaidTreatmentsPriorType,
+) -> str:
+  """Converts a `_NonPaidTreatmentsPriorType` enum to its string representation."""
+  if (
+      proto_enum
+      == _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
+  ):
+    warnings.warn(
+        "Non-paid prior type is unspecified. Resolving to 'contribution'."
+ ) + return c.TREATMENT_PRIOR_TYPE_CONTRIBUTION + return ( + _NonPaidTreatmentsPriorType.Name(proto_enum) + .replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "") + .lower() + ) + + +class HyperparametersSerde( + serde.Serde[meridian_pb.Hyperparameters, spec.ModelSpec] +): + """Serializes and deserializes a ModelSpec into a `Hyperparameters` proto. + + Note that this Serde only handles the Hyperparameters part of ModelSpec. + The 'prior' attribute of ModelSpec is serialized/deserialized separately + using DistributionSerde. + """ + + def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters: + """Serializes the given ModelSpec into a `Hyperparameters` proto.""" + hyperparameters_proto = meridian_pb.Hyperparameters( + media_effects_dist=_media_effects_dist_to_proto_enum( + obj.media_effects_dist + ), + hill_before_adstock=obj.hill_before_adstock, + unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo, + media_prior_type=_paid_media_prior_type_to_proto_enum( + obj.media_prior_type + ), + rf_prior_type=_paid_media_prior_type_to_proto_enum(obj.rf_prior_type), + paid_media_prior_type=_paid_media_prior_type_to_proto_enum( + obj.paid_media_prior_type + ), + organic_media_prior_type=_non_paid_prior_type_to_proto_enum( + obj.organic_media_prior_type + ), + organic_rf_prior_type=_non_paid_prior_type_to_proto_enum( + obj.organic_rf_prior_type + ), + non_media_treatments_prior_type=_non_paid_prior_type_to_proto_enum( + obj.non_media_treatments_prior_type + ), + enable_aks=obj.enable_aks, + ) + if obj.max_lag is not None: + hyperparameters_proto.max_lag = obj.max_lag + + if isinstance(obj.knots, int): + hyperparameters_proto.knots.append(obj.knots) + elif isinstance(obj.knots, list): + hyperparameters_proto.knots.extend(obj.knots) + + if isinstance(obj.baseline_geo, str): + hyperparameters_proto.baseline_geo_string = obj.baseline_geo + elif isinstance(obj.baseline_geo, int): + hyperparameters_proto.baseline_geo_int = obj.baseline_geo + + if obj.roi_calibration_period is not None: + hyperparameters_proto.roi_calibration_period.CopyFrom( + backend.make_tensor_proto(np.array(obj.roi_calibration_period)) + ) + if obj.rf_roi_calibration_period is not None: + hyperparameters_proto.rf_roi_calibration_period.CopyFrom( + backend.make_tensor_proto(np.array(obj.rf_roi_calibration_period)) + ) + if obj.holdout_id is not None: + hyperparameters_proto.holdout_id.CopyFrom( + backend.make_tensor_proto(np.array(obj.holdout_id)) + ) + if obj.control_population_scaling_id is not None: + hyperparameters_proto.control_population_scaling_id.CopyFrom( + backend.make_tensor_proto(np.array(obj.control_population_scaling_id)) + ) + if obj.non_media_population_scaling_id is not None: + hyperparameters_proto.non_media_population_scaling_id.CopyFrom( + backend.make_tensor_proto( + np.array(obj.non_media_population_scaling_id) + ) + ) + + if isinstance(obj.adstock_decay_spec, str): + hyperparameters_proto.global_adstock_decay = obj.adstock_decay_spec + elif isinstance(obj.adstock_decay_spec, dict): + hyperparameters_proto.adstock_decay_by_channel.channel_decays.update( + obj.adstock_decay_spec + ) + + return hyperparameters_proto + + def deserialize( + self, + serialized: meridian_pb.Hyperparameters, + serialized_version: str = "", + ) -> spec.ModelSpec: + """Deserializes the given `Hyperparameters` proto into a ModelSpec. + + Note that this only deserializes the Hyperparameters part of ModelSpec. 
+    The 'prior' attribute of ModelSpec is deserialized separately using
+    DistributionSerde; the MeridianSerde recombines the two.
+
+    Args:
+      serialized: The serialized `Hyperparameters` proto.
+      serialized_version: The version of the serialized model.
+
+    Returns:
+      A Meridian model spec container.
+    """
+    baseline_geo = None
+    baseline_geo_field = serialized.WhichOneof(sc.BASELINE_GEO_ONEOF)
+    if baseline_geo_field == sc.BASELINE_GEO_INT:
+      baseline_geo = serialized.baseline_geo_int
+    elif baseline_geo_field == sc.BASELINE_GEO_STRING:
+      baseline_geo = serialized.baseline_geo_string
+
+    knots = None
+    if serialized.knots:
+      if len(serialized.knots) == 1:
+        knots = serialized.knots[0]
+      else:
+        knots = list(serialized.knots)
+
+    max_lag = serialized.max_lag if serialized.HasField(c.MAX_LAG) else None
+
+    roi_calibration_period = (
+        backend.make_ndarray(serialized.roi_calibration_period)
+        if serialized.HasField(c.ROI_CALIBRATION_PERIOD)
+        else None
+    )
+    rf_roi_calibration_period = (
+        backend.make_ndarray(serialized.rf_roi_calibration_period)
+        if serialized.HasField(c.RF_ROI_CALIBRATION_PERIOD)
+        else None
+    )
+
+    holdout_id = (
+        backend.make_ndarray(serialized.holdout_id)
+        if serialized.HasField(sc.HOLDOUT_ID)
+        else None
+    )
+
+    control_population_scaling_id = (
+        backend.make_ndarray(serialized.control_population_scaling_id)
+        if serialized.HasField(sc.CONTROL_POPULATION_SCALING_ID)
+        else None
+    )
+
+    non_media_population_scaling_id = (
+        backend.make_ndarray(serialized.non_media_population_scaling_id)
+        if serialized.HasField(sc.NON_MEDIA_POPULATION_SCALING_ID)
+        else None
+    )
+
+    adstock_decay_spec_field = serialized.WhichOneof(sc.ADSTOCK_DECAY_SPEC)
+    if adstock_decay_spec_field == sc.GLOBAL_ADSTOCK_DECAY:
+      adstock_decay_spec = serialized.global_adstock_decay
+    elif adstock_decay_spec_field == sc.ADSTOCK_DECAY_BY_CHANNEL:
+      adstock_decay_spec = dict(
+          serialized.adstock_decay_by_channel.channel_decays
+      )
+    else:
+      adstock_decay_spec = sc.DEFAULT_DECAY
+
+    return spec.ModelSpec(
+        media_effects_dist=_proto_enum_to_media_effects_dist(
+            serialized.media_effects_dist
+        ),
+        hill_before_adstock=serialized.hill_before_adstock,
+        max_lag=max_lag,
+        unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
+        media_prior_type=_proto_enum_to_paid_media_prior_type(
+            serialized.media_prior_type
+        ),
+        rf_prior_type=_proto_enum_to_paid_media_prior_type(
+            serialized.rf_prior_type
+        ),
+        paid_media_prior_type=_proto_enum_to_paid_media_prior_type(
+            serialized.paid_media_prior_type
+        ),
+        organic_media_prior_type=_proto_enum_to_non_paid_prior_type(
+            serialized.organic_media_prior_type
+        ),
+        organic_rf_prior_type=_proto_enum_to_non_paid_prior_type(
+            serialized.organic_rf_prior_type
+        ),
+        non_media_treatments_prior_type=_proto_enum_to_non_paid_prior_type(
+            serialized.non_media_treatments_prior_type
+        ),
+        knots=knots,
+        enable_aks=serialized.enable_aks,
+        baseline_geo=baseline_geo,
+        roi_calibration_period=roi_calibration_period,
+        rf_roi_calibration_period=rf_roi_calibration_period,
+        holdout_id=holdout_id,
+        control_population_scaling_id=control_population_scaling_id,
+        non_media_population_scaling_id=non_media_population_scaling_id,
+        adstock_decay_spec=adstock_decay_spec,
+    )
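A round-trip sketch for the hyperparameters serde above. `knots` and `max_lag` are `ModelSpec` arguments handled in this diff, but the values are illustrative:

```python
from meridian.model import spec
from schema.serde import hyperparameters

serde_ = hyperparameters.HyperparametersSerde()
model_spec = spec.ModelSpec(knots=4, max_lag=8)

proto = serde_.serialize(model_spec)  # prior is intentionally not included
restored = serde_.deserialize(proto)
assert restored.max_lag == 8 and restored.knots == 4
```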
diff --git a/schema/serde/inference_data.py b/schema/serde/inference_data.py
new file mode 100644
index 000000000..ac812a407
--- /dev/null
+++ b/schema/serde/inference_data.py
@@ -0,0 +1,104 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Serialization and deserialization of `InferenceData` container for sampled priors and posteriors."""
+
+import io
+
+import arviz as az
+from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
+from schema.serde import serde
+import xarray as xr
+
+
+_NETCDF_FORMAT = "NETCDF3_64BIT"  # scipy only supports up to v3
+_PRIOR_FIELD = "prior"
+_POSTERIOR_FIELD = "posterior"
+_CREATED_AT_ATTRIBUTE = "created_at"
+
+
+def _remove_created_at_attribute(dataset: xr.Dataset) -> xr.Dataset:
+  dataset_copy = dataset.copy()
+  if _CREATED_AT_ATTRIBUTE in dataset_copy.attrs:
+    del dataset_copy.attrs[_CREATED_AT_ATTRIBUTE]
+  return dataset_copy
+
+
+class InferenceDataSerde(
+    serde.Serde[meridian_pb.InferenceData, az.InferenceData]
+):
+  """Serializes and deserializes an `InferenceData` container in Meridian.
+
+  Meridian uses `InferenceData` as a container for the sampled prior and
+  posterior datasets.
+  """
+
+  def serialize(self, obj: az.InferenceData) -> meridian_pb.InferenceData:
+    """Serializes the given Meridian inference data container into an `InferenceData` proto."""
+    if hasattr(obj, _PRIOR_FIELD):
+      prior_dataset_copy = _remove_created_at_attribute(obj.prior)  # pytype: disable=attribute-error
+      prior_bytes = bytes(prior_dataset_copy.to_netcdf(format=_NETCDF_FORMAT))
+    else:
+      prior_bytes = None
+
+    if hasattr(obj, _POSTERIOR_FIELD):
+      posterior_dataset_copy = _remove_created_at_attribute(obj.posterior)  # pytype: disable=attribute-error
+      posterior_bytes = bytes(
+          posterior_dataset_copy.to_netcdf(format=_NETCDF_FORMAT)
+      )
+    else:
+      posterior_bytes = None
+
+    aux = {}
+    for group in obj.groups():
+      if group in (_PRIOR_FIELD, _POSTERIOR_FIELD):
+        continue
+      aux_dataset_copy = _remove_created_at_attribute(obj.get(group))
+      aux[group] = bytes(aux_dataset_copy.to_netcdf(format=_NETCDF_FORMAT))
+
+    return meridian_pb.InferenceData(
+        prior=prior_bytes,
+        posterior=posterior_bytes,
+        auxiliary_data=aux,
+    )
+
+  def deserialize(
+      self, serialized: meridian_pb.InferenceData, serialized_version: str = ""
+  ) -> az.InferenceData:
+    """Deserializes the given `InferenceData` proto.
+
+    Args:
+      serialized: The serialized `InferenceData` proto.
+      serialized_version: The version of the serialized model.
+
+    Returns:
+      A Meridian inference data container.
+    """
+    groups = {}
+
+    if serialized.HasField(_PRIOR_FIELD):
+      prior_dataset = xr.open_dataset(io.BytesIO(serialized.prior))
+      groups[_PRIOR_FIELD] = prior_dataset
+
+    if serialized.HasField(_POSTERIOR_FIELD):
+      posterior_dataset = xr.open_dataset(io.BytesIO(serialized.posterior))
+      groups[_POSTERIOR_FIELD] = posterior_dataset
+
+    for name, data in serialized.auxiliary_data.items():
+      groups[name] = xr.open_dataset(io.BytesIO(data))
+
+    idata = az.InferenceData()
+    if groups:
+      idata.add_groups(groups)
+    return idata
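A round-trip sketch for this serde, mirroring the helpers in the test file that follows; `az.convert_to_inference_data` is a real ArviZ API and scipy is assumed available for the NETCDF3 encoding:

```python
import arviz as az
import numpy as np
from schema.serde import inference_data

idata = az.convert_to_inference_data(
    np.random.randn(1, 100, 3), group="posterior"
)
serde_ = inference_data.InferenceDataSerde()
proto = serde_.serialize(idata)  # each group becomes NETCDF3 bytes
restored = serde_.deserialize(proto)
# Round-trips apart from the stripped `created_at` attribute.
```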
diff --git a/schema/serde/inference_data_test.py b/schema/serde/inference_data_test.py
new file mode 100644
index 000000000..e2843485b
--- /dev/null
+++ b/schema/serde/inference_data_test.py
@@ -0,0 +1,243 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import arviz as az
+from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
+from schema.serde import inference_data as infdata
+import numpy as np
+import xarray as xr
+from tensorflow.python.util.protobuf import compare
+
+
+mock = absltest.mock
+
+
+def _mock_infdata_with_datasets(
+    include_prior=False,
+    # `include_posterior` also attaches mock sample_stats and trace groups.
+    include_posterior=False,
+ to_netcdf_returns_memoryview=False, +) -> mock.MagicMock: + idata = mock.MagicMock(spec=az.InferenceData) + mock_groups = {} + + attrs = {"created_at": "2024-08-20T12:00:00.000000000Z"} + get_return_value = ( + lambda bytes: memoryview(bytes) if to_netcdf_returns_memoryview else bytes + ) + + if include_prior: + prior = mock.MagicMock(spec=xr.Dataset) + prior.name = "prior" + prior.attrs = attrs.copy() + prior.to_netcdf.return_value = get_return_value(b"test-prior-bytes") + prior.copy.return_value = prior + idata.prior = prior + mock_groups["prior"] = prior + + if include_posterior: + posterior = mock.MagicMock(spec=xr.Dataset) + posterior.name = "posterior" + posterior.attrs = attrs.copy() + posterior.to_netcdf.return_value = get_return_value(b"test-posterior-bytes") + posterior.copy.return_value = posterior + idata.posterior = posterior + mock_groups["posterior"] = posterior + + sample_stats = mock.MagicMock(spec=xr.Dataset) + sample_stats.name = "sample_stats" + sample_stats.attrs = attrs.copy() + sample_stats.to_netcdf.return_value = get_return_value(b"test-stats-bytes") + sample_stats.copy.return_value = sample_stats + mock_groups["sample_stats"] = sample_stats + + trace = mock.MagicMock(spec=xr.Dataset) + trace.name = "trace" + trace.attrs = attrs.copy() + trace.to_netcdf.return_value = get_return_value(b"test-trace-bytes") + trace.copy.return_value = trace + mock_groups["trace"] = trace + + idata.groups.return_value = mock_groups.keys() + + def _get(group: str) -> xr.Dataset: + return mock_groups[group] + + idata.get.side_effect = _get + + return idata + + +def _create_random_infdata(group: str) -> az.InferenceData: + shape = (1, 2, 3, 4, 5) + dataset = az.convert_to_inference_data(np.random.randn(*shape), group=group) + + idata = az.InferenceData() + idata.extend(dataset, join="right") + return idata + + +def _create_prior_infdata() -> az.InferenceData: + return _create_random_infdata("prior") + + +def _create_posterior_infdata() -> az.InferenceData: + return az.concat( + _create_random_infdata("posterior"), + _create_random_infdata("trace"), + _create_random_infdata("sample_stats"), + ) + + +def _create_fully_fitted_infdata() -> az.InferenceData: + prior = _create_prior_infdata() + posterior = _create_posterior_infdata() + prior.extend(posterior, join="right") + return prior + + +class InferenceDataTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.serde = infdata.InferenceDataSerde() + + def test_serialize_no_sampled_data(self): + infdata_proto = self.serde.serialize(az.InferenceData()) + compare.assertProtoEqual( + self, + infdata_proto, + meridian_pb.InferenceData(), + ) + + @parameterized.named_parameters( + dict( + testcase_name="prior_only", + idata=_mock_infdata_with_datasets( + include_prior=True, + include_posterior=False, + ), + expected_infdata_proto=meridian_pb.InferenceData( + prior=b"test-prior-bytes", + ), + ), + dict( + testcase_name="prior_only_memoryview", + idata=_mock_infdata_with_datasets( + include_prior=True, + include_posterior=False, + to_netcdf_returns_memoryview=True, + ), + expected_infdata_proto=meridian_pb.InferenceData( + prior=b"test-prior-bytes", + ), + ), + dict( + testcase_name="posterior_only", + idata=_mock_infdata_with_datasets( + include_prior=False, + include_posterior=True, + ), + expected_infdata_proto=meridian_pb.InferenceData( + posterior=b"test-posterior-bytes", + auxiliary_data={ + "sample_stats": b"test-stats-bytes", + "trace": b"test-trace-bytes", + }, + ), + ), + dict( + 
testcase_name="posterior_only_memoryview", + idata=_mock_infdata_with_datasets( + include_prior=False, + include_posterior=True, + to_netcdf_returns_memoryview=True, + ), + expected_infdata_proto=meridian_pb.InferenceData( + posterior=b"test-posterior-bytes", + auxiliary_data={ + "sample_stats": b"test-stats-bytes", + "trace": b"test-trace-bytes", + }, + ), + ), + dict( + testcase_name="fully_fitted", + idata=_mock_infdata_with_datasets( + include_prior=True, + include_posterior=True, + ), + expected_infdata_proto=meridian_pb.InferenceData( + prior=b"test-prior-bytes", + posterior=b"test-posterior-bytes", + auxiliary_data={ + "sample_stats": b"test-stats-bytes", + "trace": b"test-trace-bytes", + }, + ), + ), + dict( + testcase_name="fully_fitted_memoryview", + idata=_mock_infdata_with_datasets( + include_prior=True, + include_posterior=True, + to_netcdf_returns_memoryview=True, + ), + expected_infdata_proto=meridian_pb.InferenceData( + prior=b"test-prior-bytes", + posterior=b"test-posterior-bytes", + auxiliary_data={ + "sample_stats": b"test-stats-bytes", + "trace": b"test-trace-bytes", + }, + ), + ), + ) + def test_serialize( + self, + idata: mock.MagicMock, + expected_infdata_proto: meridian_pb.InferenceData, + ): + infdata_proto = self.serde.serialize(idata) + compare.assertProtoEqual(self, infdata_proto, expected_infdata_proto) + self.assertNotIn("created_at", idata.attrs) + + def test_deserialize_empty(self): + infdata_proto = meridian_pb.InferenceData() + idata = self.serde.deserialize(infdata_proto) + self.assertEqual(idata, az.InferenceData()) + + @parameterized.named_parameters( + dict( + testcase_name="prior_only", + idata=_create_prior_infdata(), + ), + dict( + testcase_name="posterior_only", + idata=_create_posterior_infdata(), + ), + dict( + testcase_name="fully_fitted", + idata=_create_fully_fitted_infdata(), + ), + ) + def test_serialize_deserialize(self, idata: az.InferenceData): + ser = self.serde.serialize(idata) + de = self.serde.deserialize(ser) + self.assertEqual(de, idata) + + +if __name__ == "__main__": + absltest.main() diff --git a/schema/serde/marketing_data.py b/schema/serde/marketing_data.py new file mode 100644 index 000000000..96d264084 --- /dev/null +++ b/schema/serde/marketing_data.py @@ -0,0 +1,1320 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Serialization and deserialization of `InputData` for Meridian models.""" + +from collections.abc import Mapping +import dataclasses +import datetime as dt +import functools +import itertools +from typing import Sequence + +from meridian import constants as c +from meridian.data import input_data as meridian_input_data +from mmm.v1.common import date_interval_pb2 +from mmm.v1.marketing import marketing_data_pb2 as marketing_pb +from schema.serde import constants as sc +from schema.serde import serde +from schema.utils import time_record +import numpy as np +import xarray as xr + +from google.type import date_pb2 + +# Mapping from DataArray names to coordinate names +_COORD_NAME_MAP = { + c.MEDIA: c.MEDIA_CHANNEL, + c.REACH: c.RF_CHANNEL, + c.FREQUENCY: c.RF_CHANNEL, + c.ORGANIC_MEDIA: c.ORGANIC_MEDIA_CHANNEL, + c.ORGANIC_REACH: c.ORGANIC_RF_CHANNEL, + c.ORGANIC_FREQUENCY: c.ORGANIC_RF_CHANNEL, + c.NON_MEDIA_TREATMENTS: c.NON_MEDIA_CHANNEL, +} + + +@dataclasses.dataclass(frozen=True) +class _DeserializedTimeDimension: + """Wrapper class for `TimeDimension` proto to provide utility methods during deserialization.""" + + _time_dimension: marketing_pb.MarketingDataMetadata.TimeDimension + + def __post_init__(self): + if not self._time_dimension.dates: + raise ValueError("TimeDimension proto must have at least one date.") + + @functools.cached_property + def date_coordinates(self) -> list[dt.date]: + """Returns a list of date coordinates in this time dimension.""" + return [dt.date(d.year, d.month, d.day) for d in self._time_dimension.dates] + + @functools.cached_property + def time_dimension_interval(self) -> date_interval_pb2.DateInterval: + """Returns the `[start, end)` interval that spans this time dimension. + + This date interval spans all of the date coordinates in this time dimension. + """ + date_intervals = time_record.convert_times_to_date_intervals( + self.date_coordinates + ) + return _get_date_interval_from_date_intervals(list(date_intervals.values())) + + +@dataclasses.dataclass(frozen=True) +class _DeserializedMetadata: + """A container for parsed metadata from the `MarketingData` proto. + + Attributes: + _metadata: The `MarketingDataMetadata` proto. 
+ """ + + _metadata: marketing_pb.MarketingDataMetadata + + def __post_init__(self): + # Evaluate the properties to trigger validation + _ = self.time_dimension + _ = self.media_time_dimension + + def _get_time_dimension(self, name: str) -> _DeserializedTimeDimension: + """Helper method to get a specific TimeDimension proto by name.""" + for time_dimension in self._metadata.time_dimensions: + if time_dimension.name == name: + return _DeserializedTimeDimension(time_dimension) + raise ValueError(f"No TimeDimension found with name '{name}' in metadata.") + + @functools.cached_property + def time_dimension(self) -> _DeserializedTimeDimension: + """Returns the TimeDimension with name 'time'.""" + return self._get_time_dimension(c.TIME) + + @functools.cached_property + def media_time_dimension(self) -> _DeserializedTimeDimension: + """Returns the TimeDimension with name 'media_time'.""" + return self._get_time_dimension(c.MEDIA_TIME) + + @functools.cached_property + def channel_dimensions(self) -> Mapping[str, list[str]]: + """Returns a mapping of channel dimension names to their corresponding channel coordinate names.""" + return { + cd.name: list(cd.channels) for cd in self._metadata.channel_dimensions + } + + @functools.cached_property + def channel_types(self) -> Mapping[str, str | None]: + """Returns a mapping of individual channel names to their types.""" + channel_coord_map = {} + for name, channels in self.channel_dimensions.items(): + for channel in channels: + channel_coord_map[channel] = _COORD_NAME_MAP.get( + name, + ) + return channel_coord_map + + +def _extract_data_array( + serialized_data_points: Sequence[marketing_pb.MarketingDataPoint], + data_extractor_fn, + data_name, +) -> xr.DataArray | None: + """Helper function to extract data into an `xr.DataArray`. + + Args: + serialized_data_points: A Sequence of MarketingDataPoint protos. + data_extractor_fn: A function that takes a data point and returns either a + tuple of `(geo_id, time_str, value)`, or `None` if the data point should + be skipped. + data_name: The desired name for the `xr.DataArray`. + + Returns: + An `xr.DataArray` containing the extracted data, or `None` if no data is + found. + """ + data_dict = {} # (geo_id, time_str) -> value + geo_ids = [] + times = [] + + for data_point in serialized_data_points: + extraction_result = data_extractor_fn(data_point) + if extraction_result is None: + continue + + geo_id, time_str, value = extraction_result + + # TODO: Enforce dimension uniqueness in Meridian. + if geo_id not in geo_ids: + geo_ids.append(geo_id) + if time_str not in times: + times.append(time_str) + + data_dict[(geo_id, time_str)] = value + + if not data_dict: + return None + + data_values = np.array([ + [data_dict.get((geo_id, time), np.nan) for time in times] + for geo_id in geo_ids + ]) + + return xr.DataArray( + data=data_values, + coords={ + c.GEO: geo_ids, + c.TIME: times, + }, + dims=(c.GEO, c.TIME), + name=data_name, + ) + + +def _extract_3d_data_array( + serialized_data_points: Sequence[marketing_pb.MarketingDataPoint], + data_extractor_fn, + data_name, + third_dim_name, + time_dim_name=c.TIME, +) -> xr.DataArray | None: + """Helper function to extract data with 3 dimensions into an `xr.DataArray`. + + The first dimension is always `GEO`, and the second is the time dimension + (default: `TIME`). + + Args: + serialized_data_points: A sequence of MarketingDataPoint protos. 
+ data_extractor_fn: A function that takes a data point and returns either a + tuple of `(geo_id, time_str, third_dim_key, value)`, or `None` if the + data point should be skipped. + data_name: The desired name for the `xr.DataArray`. + third_dim_name: The name of the third dimension. + time_dim_name: The name of the time dimension. Default is `TIME`. + + Returns: + An `xr.DataArray` containing the extracted data, or `None` if no data is + found. + """ + data_dict = {} # (geo_id, time_str, third_dim_key) -> value + geo_ids = [] + times = [] + third_dim_keys = [] + + for data_point in serialized_data_points: + for extraction_result in data_extractor_fn(data_point): + geo_id, time_str, third_dim_key, value = extraction_result + + if geo_id not in geo_ids: + geo_ids.append(geo_id) + if time_str not in times: + times.append(time_str) + if third_dim_key not in third_dim_keys: + third_dim_keys.append(third_dim_key) + + # TODO: Enforce dimension uniqueness in Meridian. + data_dict[(geo_id, time_str, third_dim_key)] = value + + if not data_dict: + return None + + data_values = np.array([ + [ + [ + data_dict.get((geo_id, time, third_dim_key), np.nan) + for third_dim_key in third_dim_keys + ] + for time in times + ] + for geo_id in geo_ids + ]) + + return xr.DataArray( + data=data_values, + coords={ + c.GEO: geo_ids, + time_dim_name: times, + third_dim_name: third_dim_keys, + }, + dims=(c.GEO, time_dim_name, third_dim_name), + name=data_name, + ) + + +def _get_date_interval_from_date_intervals( + date_intervals: Sequence[date_interval_pb2.DateInterval], +) -> date_interval_pb2.DateInterval: + """Gets the date interval based on the earliest start date and latest end date. + + Args: + date_intervals: A list of DateInterval protos. + + Returns: + A DateInterval representing the earliest start date and latest end date. + """ + get_start_date = lambda interval: dt.date( + interval.start_date.year, + interval.start_date.month, + interval.start_date.day, + ) + get_end_date = lambda interval: dt.date( + interval.end_date.year, interval.end_date.month, interval.end_date.day + ) + + min_start_date_interval = min(date_intervals, key=get_start_date) + max_end_date_interval = max(date_intervals, key=get_end_date) + + return date_interval_pb2.DateInterval( + start_date=date_pb2.Date( + year=min_start_date_interval.start_date.year, + month=min_start_date_interval.start_date.month, + day=min_start_date_interval.start_date.day, + ), + end_date=date_pb2.Date( + year=max_end_date_interval.end_date.year, + month=max_end_date_interval.end_date.month, + day=max_end_date_interval.end_date.day, + ), + ) + + +class _InputDataSerializer: + """Serializes an `InputData` container in Meridian model.""" + + def __init__(self, input_data: meridian_input_data.InputData): + self._input_data = input_data + + @property + def _n_geos(self) -> int: + return len(self._input_data.geo) + + @property + def _n_times(self) -> int: + return len(self._input_data.time) + + def __call__(self) -> marketing_pb.MarketingData: + """Serializes the input data into a MarketingData proto.""" + marketing_proto = marketing_pb.MarketingData() + # Use media_time since it covers larger range. 
+    times_to_date_intervals = time_record.convert_times_to_date_intervals(
+        self._input_data.media_time.data
+    )
+    geos_and_times = itertools.product(
+        self._input_data.geo.data, self._input_data.media_time.data
+    )
+
+    for geo, time in geos_and_times:
+      data_point = self._serialize_data_point(
+          geo,
+          time,
+          times_to_date_intervals,
+      )
+      marketing_proto.marketing_data_points.append(data_point)
+
+    if self._input_data.media_spend is not None:
+      if (
+          not self._input_data.media_spend_has_geo_dimension
+          and not self._input_data.media_spend_has_time_dimension
+      ):
+        marketing_proto.marketing_data_points.append(
+            self._serialize_aggregated_media_spend_data_point(
+                self._input_data.media_spend,
+                times_to_date_intervals,
+            )
+        )
+      elif (
+          self._input_data.media_spend_has_geo_dimension
+          != self._input_data.media_spend_has_time_dimension
+      ):
+        raise AssertionError(
+            "Invalid input data: media_spend must either be fully granular"
+            " (both geo and time dimensions) or fully aggregated (neither geo"
+            " nor time dimensions)."
+        )
+
+    if self._input_data.rf_spend is not None:
+      if (
+          not self._input_data.rf_spend_has_geo_dimension
+          and not self._input_data.rf_spend_has_time_dimension
+      ):
+        marketing_proto.marketing_data_points.append(
+            self._serialize_aggregated_rf_spend_data_point(
+                self._input_data.rf_spend, times_to_date_intervals
+            )
+        )
+      elif (
+          self._input_data.rf_spend_has_geo_dimension
+          != self._input_data.rf_spend_has_time_dimension
+      ):
+        raise AssertionError(
+            "Invalid input data: rf_spend must either be fully granular (both"
+            " geo and time dimensions) or fully aggregated (neither geo nor"
+            " time dimensions)."
+        )
+
+    marketing_proto.metadata.CopyFrom(self._serialize_metadata())
+
+    return marketing_proto
+
+  def _serialize_media_variables(
+      self,
+      geo: str,
+      time: str,
+      channel_dim_name: str,
+      impressions_data_array: xr.DataArray,
+      spend_data_array: xr.DataArray | None = None,
+  ) -> list[marketing_pb.MediaVariable]:
+    """Serializes media variables for a given geo and time.
+
+    Args:
+      geo: The geo ID.
+      time: The time string.
+      channel_dim_name: The name of the channel dimension.
+      impressions_data_array: The DataArray containing impressions data.
+        Expected dimensions: `(n_geos, n_media_times, n_channels)`.
+      spend_data_array: The optional DataArray containing spend data. Expected
+        dimensions are `(n_geos, n_times, n_media_channels)`.
+
+    Returns:
+      A list of MediaVariable protos.
+    """
+    media_variables = []
+    for media_data in impressions_data_array.sel(geo=geo, media_time=time):
+      channel = media_data[channel_dim_name].item()
+      media_variable = marketing_pb.MediaVariable(
+          channel_name=channel,
+          scalar_metric=marketing_pb.ScalarMetric(
+              name=c.IMPRESSIONS, value=media_data.item()
+          ),
+      )
+      if spend_data_array is not None and time in spend_data_array.time:
+        media_variable.media_spend = spend_data_array.sel(
+            geo=geo, time=time, **{channel_dim_name: channel}
+        ).item()
+      media_variables.append(media_variable)
+    return media_variables
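For orientation, roughly the proto each loop iteration above emits for one (geo, time, channel) cell. The channel name and values are made up, and `c.IMPRESSIONS` is assumed to be the string `"impressions"`:

```python
from mmm.v1.marketing import marketing_data_pb2 as marketing_pb

var = marketing_pb.MediaVariable(
    channel_name="tv",
    scalar_metric=marketing_pb.ScalarMetric(name="impressions", value=1200.0),
)
var.media_spend = 345.6  # set only when spend is granular (geo x time)
```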
+  def _serialize_reach_frequency_variables(
+      self,
+      geo: str,
+      time: str,
+      channel_dim_name: str,
+      reach_data_array: xr.DataArray,
+      frequency_data_array: xr.DataArray,
+      spend_data_array: xr.DataArray | None = None,
+  ) -> list[marketing_pb.ReachFrequencyVariable]:
+    """Serializes reach and frequency variables for a given geo and time.
+
+    Iterates through the R&F channels separately, creating a
+    ReachFrequencyVariable for each. It's safe to assume that Meridian media
+    channel names are unique across `media_data` and `reach_data`. This
+    assumption is checked when an `InputData` is created in model training.
+
+    Dimensions of `reach_data_array` and `frequency_data_array` are expected
+    to be `(n_geos, n_media_times, n_rf_channels)`.
+
+    Args:
+      geo: The geo ID.
+      time: The time string.
+      channel_dim_name: The name of the channel dimension (e.g., 'rf_channel').
+      reach_data_array: The DataArray containing reach data.
+      frequency_data_array: The DataArray containing frequency data.
+      spend_data_array: The optional DataArray containing spend data.
+
+    Returns:
+      A list of ReachFrequencyVariable protos.
+    """
+    rf_variables = []
+    for reach_data in reach_data_array.sel(geo=geo, media_time=time):
+      reach_value = reach_data.item()
+      channel = reach_data[channel_dim_name].item()
+      frequency_value = frequency_data_array.sel(
+          geo=geo,
+          media_time=time,
+          **{channel_dim_name: channel},
+      ).item()
+      rf_variable = marketing_pb.ReachFrequencyVariable(
+          channel_name=channel,
+          reach=int(reach_value),
+          average_frequency=frequency_value,
+      )
+      if spend_data_array is not None and time in spend_data_array.time:
+        rf_variable.spend = spend_data_array.sel(
+            geo=geo, time=time, **{channel_dim_name: channel}
+        ).item()
+      rf_variables.append(rf_variable)
+    return rf_variables
+
+  def _serialize_non_media_treatment_variables(
+      self, geo: str, time: str
+  ) -> list[marketing_pb.NonMediaTreatmentVariable]:
+    """Serializes non-media treatment variables for a given geo and time.
+
+    Args:
+      geo: The geo ID.
+      time: The time string.
+
+    Returns:
+      A list of NonMediaTreatmentVariable protos.
+    """
+    non_media_treatment_variables = []
+    if (
+        self._input_data.non_media_treatments is not None
+        and geo in self._input_data.non_media_treatments.geo
+        and time in self._input_data.non_media_treatments.time
+    ):
+      for non_media_treatment_data in self._input_data.non_media_treatments.sel(
+          geo=geo, time=time
+      ):
+        non_media_treatment_variables.append(
+            marketing_pb.NonMediaTreatmentVariable(
+                name=non_media_treatment_data[c.NON_MEDIA_CHANNEL].item(),
+                value=non_media_treatment_data.item(),
+            )
+        )
+    return non_media_treatment_variables
+
+  def _serialize_data_point(
+      self,
+      geo: str,
+      time: str,
+      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
+  ) -> marketing_pb.MarketingDataPoint:
+    """Serializes a MarketingDataPoint proto for a given geo and time."""
+    data_point = marketing_pb.MarketingDataPoint(
+        geo_info=marketing_pb.GeoInfo(
+            geo_id=geo,
+            population=round(self._input_data.population.sel(geo=geo).item()),
+        ),
+        date_interval=times_to_date_intervals.get(time),
+    )
+
+    if self._input_data.controls is not None:
+      if time in self._input_data.controls.time:
+        for control_data in self._input_data.controls.sel(geo=geo, time=time):
+          data_point.control_variables.add(
+              name=control_data.control_variable.item(),
+              value=control_data.item(),
+          )
+
+    if self._input_data.media is not None:
+      if (
+          self._input_data.media_spend_has_geo_dimension
+          and self._input_data.media_spend_has_time_dimension
+      ):
+        spend_data_array = self._input_data.media_spend
+      else:
+        # Aggregated spend data is serialized in a separate data point.
+        spend_data_array = None
+      media_variables = self._serialize_media_variables(
+          geo,
+          time,
+          c.MEDIA_CHANNEL,
+          self._input_data.media,
+          spend_data_array,
+      )
+      data_point.media_variables.extend(media_variables)
+
+    if (
+        self._input_data.reach is not None
+        and self._input_data.frequency is not None
+    ):
+      if (
+          self._input_data.rf_spend_has_geo_dimension
+          and self._input_data.rf_spend_has_time_dimension
+      ):
+        rf_spend_data_array = self._input_data.rf_spend
+      else:
+        # Aggregated spend data is serialized in a separate data point.
+        rf_spend_data_array = None
+      rf_variables = self._serialize_reach_frequency_variables(
+          geo,
+          time,
+          c.RF_CHANNEL,
+          self._input_data.reach,
+          self._input_data.frequency,
+          rf_spend_data_array,
+      )
+      data_point.reach_frequency_variables.extend(rf_variables)
+
+    if self._input_data.organic_media is not None:
+      organic_media_variables = self._serialize_media_variables(
+          geo, time, c.ORGANIC_MEDIA_CHANNEL, self._input_data.organic_media
+      )
+      data_point.media_variables.extend(organic_media_variables)
+
+    if (
+        self._input_data.organic_reach is not None
+        and self._input_data.organic_frequency is not None
+    ):
+      organic_rf_variables = self._serialize_reach_frequency_variables(
+          geo,
+          time,
+          c.ORGANIC_RF_CHANNEL,
+          self._input_data.organic_reach,
+          self._input_data.organic_frequency,
+      )
+      data_point.reach_frequency_variables.extend(organic_rf_variables)
+
+    non_media_treatment_variables = (
+        self._serialize_non_media_treatment_variables(geo, time)
+    )
+    data_point.non_media_treatment_variables.extend(
+        non_media_treatment_variables
+    )
+
+    if time in self._input_data.kpi.time:
+      kpi_proto = self._make_kpi_proto(geo, time)
+      data_point.kpi.CopyFrom(kpi_proto)
+
+    return data_point
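The spend handling above and the aggregated-spend serializers below correspond to the two shapes `__call__` accepts. A sketch of both, with hypothetical sizes:

```python
import numpy as np
import xarray as xr

# Fully granular: spend rides along on each (geo, time) data point.
granular = xr.DataArray(
    np.zeros((2, 3, 1)), dims=("geo", "time", "media_channel")
)

# Fully aggregated: one total per channel, emitted as a single extra data
# point whose date interval spans the entire media_time range.
aggregated = xr.DataArray(np.zeros(1), dims=("media_channel",))

# A shape carrying exactly one of geo/time is rejected in __call__.
```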
+  def _serialize_aggregated_media_spend_data_point(
+      self,
+      spend_data_array: xr.DataArray,
+      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
+  ) -> marketing_pb.MarketingDataPoint:
+    """Serializes aggregated media spend into its own data point."""
+    spend_data_point = marketing_pb.MarketingDataPoint()
+    date_interval = _get_date_interval_from_date_intervals(
+        list(times_to_date_intervals.values())
+    )
+    spend_data_point.date_interval.CopyFrom(date_interval)
+
+    for channel_name in spend_data_array.coords[c.MEDIA_CHANNEL].values:
+      spend_value = spend_data_array.sel(
+          **{c.MEDIA_CHANNEL: channel_name}
+      ).item()
+      spend_data_point.media_variables.add(
+          channel_name=channel_name, media_spend=spend_value
+      )
+
+    return spend_data_point
+
+  def _serialize_aggregated_rf_spend_data_point(
+      self,
+      spend_data_array: xr.DataArray,
+      times_to_date_intervals: Mapping[str, date_interval_pb2.DateInterval],
+  ) -> marketing_pb.MarketingDataPoint:
+    """Serializes aggregated reach-and-frequency spend into its own data point."""
+    spend_data_point = marketing_pb.MarketingDataPoint()
+    date_interval = _get_date_interval_from_date_intervals(
+        list(times_to_date_intervals.values())
+    )
+    spend_data_point.date_interval.CopyFrom(date_interval)
+
+    for channel_name in spend_data_array.coords[c.RF_CHANNEL].values:
+      spend_value = spend_data_array.sel(**{c.RF_CHANNEL: channel_name}).item()
+      spend_data_point.reach_frequency_variables.add(
+          channel_name=channel_name, spend=spend_value
+      )
+
+    return spend_data_point
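`_serialize_time_dimensions` (below) parses Meridian's string time coordinates into proto date fields. A one-liner sketch, assuming `c.DATE_FORMAT` is `"%Y-%m-%d"` as elsewhere in Meridian:

```python
import datetime as dt

date_obj = dt.datetime.strptime("2024-01-01", "%Y-%m-%d").date()
# -> dates.add(year=2024, month=1, day=1) on the TimeDimension proto
print(date_obj.year, date_obj.month, date_obj.day)
```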
+  def _serialize_time_dimensions(
+      self, name: str, time_data: xr.DataArray
+  ) -> marketing_pb.MarketingDataMetadata.TimeDimension:
+    """Creates a TimeDimension message."""
+    time_dimensions = marketing_pb.MarketingDataMetadata.TimeDimension(
+        name=name
+    )
+    for date in time_data.values:
+      date_obj = dt.datetime.strptime(date, c.DATE_FORMAT).date()
+      time_dimensions.dates.add(
+          year=date_obj.year, month=date_obj.month, day=date_obj.day
+      )
+    return time_dimensions
+
+  def _serialize_channel_dimensions(
+      self, channel_data: xr.DataArray | None
+  ) -> marketing_pb.MarketingDataMetadata.ChannelDimension | None:
+    """Creates a ChannelDimension message if the corresponding attribute exists."""
+    if channel_data is None:
+      return None
+
+    coord_name = _COORD_NAME_MAP.get(channel_data.name)
+    if coord_name:
+      return marketing_pb.MarketingDataMetadata.ChannelDimension(
+          name=channel_data.name,
+          channels=channel_data.coords[coord_name].values.tolist(),
+      )
+    else:
+      # Make sure that all channel dimensions are handled.
+      raise ValueError(f"Unknown channel data name: {channel_data.name}.")
+
+  def _serialize_metadata(self) -> marketing_pb.MarketingDataMetadata:
+    """Serializes metadata from InputData to MarketingDataMetadata."""
+    metadata = marketing_pb.MarketingDataMetadata()
+
+    metadata.time_dimensions.append(
+        self._serialize_time_dimensions(c.TIME, self._input_data.time)
+    )
+    metadata.time_dimensions.append(
+        self._serialize_time_dimensions(
+            c.MEDIA_TIME, self._input_data.media_time
+        )
+    )
+
+    channel_data_arrays = [
+        self._input_data.media,
+        self._input_data.reach,
+        self._input_data.frequency,
+        self._input_data.organic_media,
+        self._input_data.organic_reach,
+        self._input_data.organic_frequency,
+    ]
+
+    for channel_data_array in channel_data_arrays:
+      channel_names_message = self._serialize_channel_dimensions(
+          channel_data_array
+      )
+      if channel_names_message:
+        metadata.channel_dimensions.append(channel_names_message)
+
+    if self._input_data.controls is not None:
+      metadata.control_names.extend(
+          self._input_data.controls.control_variable.values
+      )
+
+    if self._input_data.non_media_treatments is not None:
+      metadata.non_media_treatment_names.extend(
+          self._input_data.non_media_treatments.non_media_channel.values
+      )
+
+    metadata.kpi_type = self._input_data.kpi_type
+
+    return metadata
+
+  def _make_kpi_proto(self, geo: str, time: str) -> marketing_pb.Kpi:
+    """Constructs a Kpi proto from the `InputData` KPI values."""
+    kpi_proto = marketing_pb.Kpi(name=self._input_data.kpi_type)
+    # `kpi` and `revenue_per_kpi` dimensions: `(n_geos, n_times)`.
+    if self._input_data.kpi_type == c.REVENUE:
+      kpi_proto.revenue.CopyFrom(
+          marketing_pb.Kpi.Revenue(
+              value=self._input_data.kpi.sel(geo=geo, time=time).item()
+          )
+      )
+    else:
+      kpi_proto.non_revenue.CopyFrom(
+          marketing_pb.Kpi.NonRevenue(
+              value=self._input_data.kpi.sel(geo=geo, time=time).item()
+          )
+      )
+      if self._input_data.revenue_per_kpi is not None:
+        kpi_proto.non_revenue.revenue_per_kpi = (
+            self._input_data.revenue_per_kpi.sel(geo=geo, time=time).item()
+        )
+    return kpi_proto
+
+
+class _InputDataDeserializer:
+  """Deserializes a `MarketingData` proto into a Meridian `InputData`."""
+
+  def __init__(self, serialized: marketing_pb.MarketingData):
+    self._serialized = serialized
+    # Validate eagerly. This check previously lived in a `__post_init__`
+    # method, but that is a dataclass-only hook and never runs on a plain
+    # class, so it belongs in `__init__` itself.
+    if not self._serialized.HasField(sc.METADATA):
+      raise ValueError(
+          f"MarketingData proto is missing the '{sc.METADATA}' field."
+ ) + + @functools.cached_property + def _metadata(self) -> _DeserializedMetadata: + """Parses metadata and extracts time dimensions, channel dimensions, and channel type map.""" + return _DeserializedMetadata(self._serialized.metadata) + + def _extract_population(self) -> xr.DataArray: + """Extracts population data from the serialized proto.""" + geo_populations = {} + + for data_point in self._serialized.marketing_data_points: + geo_id = data_point.geo_info.geo_id + if not geo_id: + continue + + geo_populations[geo_id] = data_point.geo_info.population + + return xr.DataArray( + coords={c.GEO: list(geo_populations.keys())}, + data=np.array(list(geo_populations.values())), + name=c.POPULATION, + ) + + def _extract_kpi_type(self) -> str: + """Extracts the kpi_type from the serialized proto.""" + kpi_type = None + for data_point in self._serialized.marketing_data_points: + if data_point.HasField(c.KPI): + current_kpi_type = data_point.kpi.WhichOneof(c.TYPE) + + if kpi_type is None: + kpi_type = current_kpi_type + elif kpi_type != current_kpi_type: + raise ValueError( + "Inconsistent kpi_type found in the data. " + f"Expected {kpi_type}, found {current_kpi_type}" + ) + + if kpi_type is None: + raise ValueError("kpi_type not found in the data.") + return kpi_type + + def _extract_geo_and_time(self, data_point) -> tuple[str | None, str]: + """Extracts geo_id and time_str from a data_point.""" + geo_id = data_point.geo_info.geo_id + start_date = data_point.date_interval.start_date + time_str = dt.datetime( + start_date.year, start_date.month, start_date.day + ).strftime(c.DATE_FORMAT) + return geo_id, time_str + + def _extract_kpi(self, kpi_type: str) -> xr.DataArray: + """Extracts KPI data from the serialized proto.""" + + def _kpi_extractor(data_point): + if not data_point.HasField(c.KPI): + return None + + geo_id, time_str = self._extract_geo_and_time(data_point) + + if data_point.kpi.WhichOneof(c.TYPE) != kpi_type: + raise ValueError( + "Inconsistent kpi_type found in the data. " + f"Expected {kpi_type}, found" + f" {data_point.kpi.WhichOneof(c.TYPE)}" + ) + + kpi_value = ( + data_point.kpi.revenue.value + if kpi_type == c.REVENUE + else data_point.kpi.non_revenue.value + ) + return geo_id, time_str, kpi_value + + kpi = _extract_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_kpi_extractor, + data_name=c.KPI, + ) + + if kpi is None: + raise ValueError(f"{c.KPI} is not found in the data.") + + return kpi + + def _extract_revenue_per_kpi(self, kpi_type: str) -> xr.DataArray | None: + """Extracts revenue per KPI data from the serialized proto.""" + + if kpi_type == c.REVENUE: + raise ValueError( + f"{c.REVENUE_PER_KPI} is not applicable when kpi_type is {c.REVENUE}." + ) + + def _revenue_per_kpi_extractor(data_point): + if not data_point.HasField(c.KPI): + return None + + if not data_point.kpi.non_revenue.HasField(c.REVENUE_PER_KPI): + return None + + geo_id, time_str = self._extract_geo_and_time(data_point) + + if data_point.kpi.WhichOneof(c.TYPE) != kpi_type: + raise ValueError( + "Inconsistent kpi_type found in the data. 
" + f"Expected {kpi_type}, found" + f" {data_point.kpi.WhichOneof(c.TYPE)}" + ) + + return geo_id, time_str, data_point.kpi.non_revenue.revenue_per_kpi + + return _extract_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_revenue_per_kpi_extractor, + data_name=c.REVENUE_PER_KPI, + ) + + def _extract_controls(self) -> xr.DataArray | None: + """Extracts control variables data from the serialized proto, if any.""" + + def _controls_extractor(data_point): + if not data_point.control_variables: + return None + + geo_id, time_str = self._extract_geo_and_time(data_point) + + for control_variable in data_point.control_variables: + control_name = control_variable.name + control_value = control_variable.value + yield geo_id, time_str, control_name, control_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_controls_extractor, + data_name=c.CONTROLS, + third_dim_name=c.CONTROL_VARIABLE, + ) + + def _extract_media(self) -> xr.DataArray | None: + """Extracts media variables data from the serialized proto.""" + + def _media_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + return None + + for media_variable in data_point.media_variables: + channel_name = media_variable.channel_name + if self._metadata.channel_types.get(channel_name) != c.MEDIA_CHANNEL: + continue + + media_value = media_variable.scalar_metric.value + yield geo_id, time_str, channel_name, media_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_media_extractor, + data_name=c.MEDIA, + third_dim_name=c.MEDIA_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_reach(self) -> xr.DataArray | None: + """Extracts reach data from the serialized proto.""" + + def _reach_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + return None + + for rf_variable in data_point.reach_frequency_variables: + channel_name = rf_variable.channel_name + if self._metadata.channel_types.get(channel_name) != c.RF_CHANNEL: + continue + + reach_value = rf_variable.reach + yield geo_id, time_str, channel_name, reach_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_reach_extractor, + data_name=c.REACH, + third_dim_name=c.RF_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_frequency(self) -> xr.DataArray | None: + """Extracts frequency data from the serialized proto.""" + + def _frequency_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + return None + + for rf_variable in data_point.reach_frequency_variables: + channel_name = rf_variable.channel_name + if self._metadata.channel_types.get(channel_name) != c.RF_CHANNEL: + continue + + frequency_value = rf_variable.average_frequency + yield geo_id, time_str, channel_name, frequency_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_frequency_extractor, + data_name=c.FREQUENCY, + third_dim_name=c.RF_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_organic_media(self) -> xr.DataArray | None: + """Extracts organic media variables data from the serialized proto.""" + + def _organic_media_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + 
return None + + for media_variable in data_point.media_variables: + channel_name = media_variable.channel_name + if ( + self._metadata.channel_types.get(channel_name) + != c.ORGANIC_MEDIA_CHANNEL + ): + continue + + media_value = media_variable.scalar_metric.value + yield geo_id, time_str, channel_name, media_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_organic_media_extractor, + data_name=c.ORGANIC_MEDIA, + third_dim_name=c.ORGANIC_MEDIA_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_organic_reach(self) -> xr.DataArray | None: + """Extracts organic reach data from the serialized proto.""" + + def _organic_reach_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + return None + + for rf_variable in data_point.reach_frequency_variables: + channel_name = rf_variable.channel_name + if ( + self._metadata.channel_types.get(channel_name) + != c.ORGANIC_RF_CHANNEL + ): + continue + + reach_value = rf_variable.reach + yield geo_id, time_str, channel_name, reach_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_organic_reach_extractor, + data_name=c.ORGANIC_REACH, + third_dim_name=c.ORGANIC_RF_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_organic_frequency(self) -> xr.DataArray | None: + """Extracts organic frequency data from the serialized proto.""" + + def _organic_frequency_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + + if not geo_id: + return None + + for rf_variable in data_point.reach_frequency_variables: + channel_name = rf_variable.channel_name + if ( + self._metadata.channel_types.get(channel_name) + != c.ORGANIC_RF_CHANNEL + ): + continue + + frequency_value = rf_variable.average_frequency + yield geo_id, time_str, channel_name, frequency_value + + return _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_organic_frequency_extractor, + data_name=c.ORGANIC_FREQUENCY, + third_dim_name=c.ORGANIC_RF_CHANNEL, + time_dim_name=c.MEDIA_TIME, + ) + + def _extract_granular_media_spend( + self, + data_points_with_spend: list[marketing_pb.MarketingDataPoint], + ) -> xr.DataArray | None: + """Extracts granular spend data. + + Args: + data_points_with_spend: List of MarketingDataPoint protos with spend data. + + Returns: + An xr.DataArray with granular spend data or None if no data found. + """ + + def _granular_spend_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + for media_variable in data_point.media_variables: + if ( + media_variable.HasField(c.MEDIA_SPEND) + and self._metadata.channel_types.get(media_variable.channel_name) + == c.MEDIA_CHANNEL + ): + yield geo_id, time_str, media_variable.channel_name, media_variable.media_spend + + return _extract_3d_data_array( + serialized_data_points=data_points_with_spend, + data_extractor_fn=_granular_spend_extractor, + data_name=c.MEDIA_SPEND, + third_dim_name=c.MEDIA_CHANNEL, + time_dim_name=c.TIME, + ) + + def _extract_granular_rf_spend( + self, + data_points_with_spend: list[marketing_pb.MarketingDataPoint], + ) -> xr.DataArray | None: + """Extracts granular spend data. + + Args: + data_points_with_spend: List of MarketingDataPoint protos with spend data. + + Returns: + An xr.DataArray with granular spend data or None if no data found. 
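+
+    A sketch of the granular layout this returns, assuming one geo, two time
+    periods, and one RF channel (coordinate values are illustrative):
+
+    ```python
+    import numpy as np
+    import xarray as xr
+
+    rf_spend = xr.DataArray(
+        data=np.array([[[502.0], [504.0]]]),
+        coords={
+            "geo": ["geo_0"],
+            "time": ["2021-02-01", "2021-02-08"],
+            "rf_channel": ["rf_ch_paid_0"],
+        },
+        dims=["geo", "time", "rf_channel"],
+        name="rf_spend",
+    )
+    ```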
+ """ + + def _granular_spend_extractor(data_point): + geo_id, time_str = self._extract_geo_and_time(data_point) + for rf_variable in data_point.reach_frequency_variables: + if ( + rf_variable.HasField(c.SPEND) + and self._metadata.channel_types.get(rf_variable.channel_name) + == c.RF_CHANNEL + ): + yield geo_id, time_str, rf_variable.channel_name, rf_variable.spend + + return _extract_3d_data_array( + serialized_data_points=data_points_with_spend, + data_extractor_fn=_granular_spend_extractor, + data_name=c.RF_SPEND, + third_dim_name=c.RF_CHANNEL, + time_dim_name=c.TIME, + ) + + def _extract_aggregated_media_spend( + self, + data_points_with_spend: list[marketing_pb.MarketingDataPoint], + ) -> xr.DataArray | None: + """Extracts aggregated spend data. + + Args: + data_points_with_spend: List of MarketingDataPoint protos with spend data. + + Returns: + An xr.DataArray with aggregated spend data or None if no data found. + """ + channel_names = self._metadata.channel_dimensions.get(c.MEDIA, []) + channel_spend_map = {} + + for spend_data_point in data_points_with_spend: + for media_variable in spend_data_point.media_variables: + if ( + media_variable.channel_name in channel_names + and media_variable.HasField(c.MEDIA_SPEND) + ): + channel_spend_map[media_variable.channel_name] = ( + media_variable.media_spend + ) + + if not channel_spend_map: + return None + + return xr.DataArray( + data=list(channel_spend_map.values()), + coords={c.MEDIA_CHANNEL: list(channel_spend_map.keys())}, + dims=[c.MEDIA_CHANNEL], + name=c.MEDIA_SPEND, + ) + + def _extract_aggregated_rf_spend( + self, + data_points_with_spend: list[marketing_pb.MarketingDataPoint], + ) -> xr.DataArray | None: + """Extracts aggregated spend data. + + Args: + data_points_with_spend: List of MarketingDataPoint protos with spend data. + + Returns: + An xr.DataArray with aggregated spend data or None if no data found. + """ + channel_names = self._metadata.channel_dimensions.get(c.REACH, []) + channel_spend_map = {} + + for spend_data_point in data_points_with_spend: + for rf_variable in spend_data_point.reach_frequency_variables: + if rf_variable.channel_name in channel_names and rf_variable.HasField( + c.SPEND + ): + channel_spend_map[rf_variable.channel_name] = rf_variable.spend + + if not channel_spend_map: + return None + + return xr.DataArray( + data=list(channel_spend_map.values()), + coords={c.RF_CHANNEL: list(channel_spend_map.keys())}, + dims=[c.RF_CHANNEL], + name=c.RF_SPEND, + ) + + def _is_aggregated_spend_data_point( + self, dp: marketing_pb.MarketingDataPoint + ) -> bool: + """Checks if a MarketingDataPoint with spend represents aggregated spend data. + + Args: + dp: A marketing_pb.MarketingDataPoint representing a spend data point. + + Returns: + True if the data point represents aggregated spend, False otherwise. + """ + if not dp.HasField(sc.GEO_INFO) and self._metadata.media_time_dimension: + media_time_interval = ( + self._metadata.media_time_dimension.time_dimension_interval + ) + return ( + media_time_interval.start_date == dp.date_interval.start_date + and media_time_interval.end_date == dp.date_interval.end_date + ) + return False + + def _extract_media_spend(self) -> xr.DataArray | None: + """Extracts media spend data from the serialized proto. + + Returns: + An xr.DataArray with spend data or None if no data found. 
+ """ + # Filter data points relevant to spend based on channel type map + media_channels = { + channel + for channel, metadata_channel_type in self._metadata.channel_types.items() + if metadata_channel_type == c.MEDIA_CHANNEL + } + spend_data_points = [ + dp + for dp in self._serialized.marketing_data_points + if any( + mv.HasField(c.MEDIA_SPEND) and mv.channel_name in media_channels + for mv in dp.media_variables + ) + ] + + if not spend_data_points: + return None + + aggregated_spend_data_points = [ + dp + for dp in spend_data_points + if self._is_aggregated_spend_data_point(dp) + ] + + if aggregated_spend_data_points: + return self._extract_aggregated_media_spend(aggregated_spend_data_points) + + return self._extract_granular_media_spend(spend_data_points) + + def _extract_rf_spend(self) -> xr.DataArray | None: + """Extracts reach and frequency spend data from the serialized proto. + + Returns: + An xr.DataArray with spend data or None if no data found. + """ + # Filter data points relevant to spend based on channel type map + rf_channels = { + channel + for channel, metadata_channel_type in self._metadata.channel_types.items() + if metadata_channel_type == c.RF_CHANNEL + } + spend_data_points = [ + dp + for dp in self._serialized.marketing_data_points + if any( + mv.HasField(c.SPEND) and mv.channel_name in rf_channels + for mv in dp.reach_frequency_variables + ) + ] + + if not spend_data_points: + return None + + aggregated_spend_data_points = [ + dp + for dp in spend_data_points + if self._is_aggregated_spend_data_point(dp) + ] + + if aggregated_spend_data_points: + return self._extract_aggregated_rf_spend(aggregated_spend_data_points) + + return self._extract_granular_rf_spend(spend_data_points) + + def _extract_non_media_treatments(self) -> xr.DataArray | None: + """Extracts non-media treatment variables data from the serialized proto.""" + + def _non_media_treatments_extractor(data_point): + if not data_point.non_media_treatment_variables: + return None + + geo_id, time_str = self._extract_geo_and_time(data_point) + + for ( + non_media_treatment_variable + ) in data_point.non_media_treatment_variables: + treatment_name = non_media_treatment_variable.name + treatment_value = non_media_treatment_variable.value + yield geo_id, time_str, treatment_name, treatment_value + + non_media_treatments_data_array = _extract_3d_data_array( + serialized_data_points=self._serialized.marketing_data_points, + data_extractor_fn=_non_media_treatments_extractor, + data_name=c.NON_MEDIA_TREATMENTS, + third_dim_name=c.NON_MEDIA_CHANNEL, + ) + + return non_media_treatments_data_array + + def __call__(self) -> meridian_input_data.InputData: + """Converts the `MarketingData` proto to a Meridian `InputData`.""" + kpi_type = self._extract_kpi_type() + return meridian_input_data.InputData( + kpi=self._extract_kpi(kpi_type), + kpi_type=kpi_type, + controls=self._extract_controls(), + population=self._extract_population(), + revenue_per_kpi=( + self._extract_revenue_per_kpi(kpi_type) + if kpi_type == c.NON_REVENUE + else None + ), + media=self._extract_media(), + media_spend=self._extract_media_spend(), + reach=self._extract_reach(), + frequency=self._extract_frequency(), + rf_spend=self._extract_rf_spend(), + organic_media=self._extract_organic_media(), + organic_reach=self._extract_organic_reach(), + organic_frequency=self._extract_organic_frequency(), + non_media_treatments=self._extract_non_media_treatments(), + ) + + +class MarketingDataSerde( + serde.Serde[marketing_pb.MarketingData, 
meridian_input_data.InputData] +): + """Serializes and deserializes an `InputData` container in Meridian.""" + + def serialize( + self, obj: meridian_input_data.InputData + ) -> marketing_pb.MarketingData: + """Serializes the given Meridian input data into a `MarketingData` proto.""" + return _InputDataSerializer(obj)() + + def deserialize( + self, serialized: marketing_pb.MarketingData, serialized_version: str = "" + ) -> meridian_input_data.InputData: + """Deserializes the given `MarketingData` proto. + + Args: + serialized: The serialized `MarketingData` proto. + serialized_version: The version of the serialized model. + + Returns: + A Meridian input data container. + """ + return _InputDataDeserializer(serialized)() diff --git a/schema/serde/marketing_data_test.py b/schema/serde/marketing_data_test.py new file mode 100644 index 000000000..72ba39bda --- /dev/null +++ b/schema/serde/marketing_data_test.py @@ -0,0 +1,1093 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import arviz as az +from meridian import backend +from meridian import constants as c +from meridian.analysis import analyzer +from meridian.analysis import visualizer +from meridian.model import media +from meridian.model import model +from meridian.model import spec +from mmm.v1.marketing import marketing_data_pb2 as marketing_pb +from schema.serde import marketing_data +from schema.serde import test_data +import numpy as np +import xarray.testing as xrt + +from tensorflow.python.util.protobuf import compare +from google.protobuf import text_format + + +class MarketingDataTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + self._mock_analyzer = self.enter_context( + mock.patch.object(analyzer, 'Analyzer', autospec=True) + ) + self._mock_visualizer = self.enter_context( + mock.patch.object(visualizer, 'ModelDiagnostics', autospec=True) + ) + + self.serde = marketing_data.MarketingDataSerde() + + def _mock_meridian(self) -> mock.MagicMock: + """Creates a mock MMM object with InferenceData based on given flags. + + Returns: + A mock MMM object with InferenceData. 
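+
+    The tensor values are arbitrary placeholders and the `az.InferenceData()`
+    is empty: the mock only needs enough structure for the serde under test,
+    so no MCMC sampling is ever run.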
+ """ + return mock.MagicMock( + spec=model.Meridian, + controls_scaled=backend.to_tensor( + np.full((2, 3), 5.0), dtype=backend.float32 + ), + kpi_scaled=backend.to_tensor(np.full((4,), 6.0), dtype=backend.float32), + media_tensors=media.MediaTensors(), + rf_tensors=media.RfTensors(), + inference_data=az.InferenceData(), + model_spec=spec.ModelSpec(), + ) + + def _setup_meridian(self): + self._mock_meridian = self._mock_meridian() + + @parameterized.named_parameters( + dict( + testcase_name='national_media_and_rf_non_revenue', + input_data=test_data.MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE, + expected_proto=test_data.MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='national_media_and_rf_non_revenue_no_controls', + input_data=test_data.MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS, + expected_proto=test_data.MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='media_paid_expanded_lagged', + input_data=test_data.MOCK_INPUT_DATA_MEDIA_PAID_EXPANDED_LAGGED, + expected_proto=test_data.MOCK_PROTO_MEDIA_PAID_EXPANDED_LAGGED, + n_geos=2, + n_times=2, + ), + dict( + testcase_name='media_paid_granular_not_lagged', + input_data=test_data.MOCK_INPUT_DATA_MEDIA_PAID_GRANULAR_NOT_LAGGED, + expected_proto=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='rf_paid_expanded_lagged', + input_data=test_data.MOCK_INPUT_DATA_RF_PAID_EXPANDED_LAGGED, + expected_proto=test_data.MOCK_PROTO_RF_PAID_EXPANDED_LAGGED, + n_geos=2, + n_times=2, + ), + dict( + testcase_name='rf_paid_granular_not_lagged', + input_data=test_data.MOCK_INPUT_DATA_RF_PAID_GRANULAR_NOT_LAGGED, + expected_proto=test_data.MOCK_PROTO_RF_PAID_GRANULAR_NOT_LAGGED, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='media_organic_expanded_lagged', + input_data=test_data.MOCK_INPUT_DATA_MEDIA_ORGANIC_EXPANDED_LAGGED, + expected_proto=test_data.MOCK_PROTO_MEDIA_ORGANIC_EXPANDED_LAGGED, + n_geos=2, + n_times=2, + ), + dict( + testcase_name='media_organic_granular_not_lagged', + input_data=test_data.MOCK_INPUT_DATA_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED, + expected_proto=test_data.MOCK_PROTO_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='rf_organic_expanded_lagged', + input_data=test_data.MOCK_INPUT_DATA_RF_ORGANIC_EXPANDED_LAGGED, + expected_proto=test_data.MOCK_PROTO_RF_ORGANIC_EXPANDED_LAGGED, + n_geos=2, + n_times=2, + ), + dict( + testcase_name='rf_organic_granular_not_lagged', + input_data=test_data.MOCK_INPUT_DATA_RF_ORGANIC_GRANULAR_NOT_LAGGED, + expected_proto=test_data.MOCK_PROTO_RF_ORGANIC_GRANULAR_NOT_LAGGED, + n_geos=1, + n_times=2, + ), + dict( + testcase_name='non_media_treatments', + input_data=test_data.MOCK_INPUT_DATA_NON_MEDIA_TREATMENTS, + expected_proto=test_data.MOCK_PROTO_NON_MEDIA_TREATMENTS, + n_geos=2, + n_times=2, + ), + dict( + testcase_name='no_revenue_per_kpi', + input_data=test_data.MOCK_INPUT_DATA_NO_REVENUE_PER_KPI, + expected_proto=test_data.MOCK_PROTO_NO_REVENUE_PER_KPI, + n_geos=2, + n_times=2, + ), + ) + def test_serialize_marketing_data( + self, input_data, expected_proto, n_geos, n_times + ): + self._setup_meridian() + self._mock_meridian.n_geos = n_geos + self._mock_meridian.n_times = n_times + self._mock_meridian.input_data = input_data + + actual = self.serde.serialize(input_data) + + compare.assertProtoEqual(self, expected_proto, actual) + + def 
test_serialize_metadata_unknown_channel_data_name(self): + input_data = test_data.MOCK_INPUT_DATA_MEDIA_PAID_EXPANDED_LAGGED + unknown_channel_name = 'unknown_channel' + input_data.media.name = unknown_channel_name + + with self.assertRaisesRegex( + ValueError, f'Unknown channel data name: {unknown_channel_name}.' + ): + self.serde.serialize(input_data) + + @parameterized.named_parameters( + dict( + testcase_name='national_media_and_rf_non_revenue_no_controls', + marketing_data_proto=test_data.MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS, + expected_input_data=test_data.MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS, + ), + dict( + testcase_name='media_and_rf_non_revenue', + marketing_data_proto=test_data.MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE, + expected_input_data=test_data.MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE, + ), + dict( + testcase_name='media_paid_expanded_lagged', + marketing_data_proto=test_data.MOCK_PROTO_MEDIA_PAID_EXPANDED_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_MEDIA_PAID_EXPANDED_LAGGED, + ), + dict( + testcase_name='media_paid_granular_not_lagged', + marketing_data_proto=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_MEDIA_PAID_GRANULAR_NOT_LAGGED, + ), + dict( + testcase_name='rf_paid_expanded_lagged', + marketing_data_proto=test_data.MOCK_PROTO_RF_PAID_EXPANDED_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_RF_PAID_EXPANDED_LAGGED, + ), + dict( + testcase_name='rf_paid_granular_not_lagged', + marketing_data_proto=test_data.MOCK_PROTO_RF_PAID_GRANULAR_NOT_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_RF_PAID_GRANULAR_NOT_LAGGED, + ), + dict( + testcase_name='media_organic_expanded_lagged', + marketing_data_proto=test_data.MOCK_PROTO_MEDIA_ORGANIC_EXPANDED_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_MEDIA_ORGANIC_EXPANDED_LAGGED, + ), + dict( + testcase_name='media_organic_granular_not_lagged', + marketing_data_proto=test_data.MOCK_PROTO_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED, + ), + dict( + testcase_name='rf_organic_expanded_lagged', + marketing_data_proto=test_data.MOCK_PROTO_RF_ORGANIC_EXPANDED_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_RF_ORGANIC_EXPANDED_LAGGED, + ), + dict( + testcase_name='rf_organic_granular_not_lagged', + marketing_data_proto=test_data.MOCK_PROTO_RF_ORGANIC_GRANULAR_NOT_LAGGED, + expected_input_data=test_data.MOCK_INPUT_DATA_RF_ORGANIC_GRANULAR_NOT_LAGGED, + ), + dict( + testcase_name='non_media_treatments', + marketing_data_proto=test_data.MOCK_PROTO_NON_MEDIA_TREATMENTS, + expected_input_data=test_data.MOCK_INPUT_DATA_NON_MEDIA_TREATMENTS, + ), + dict( + testcase_name='no_revenue_per_kpi', + marketing_data_proto=test_data.MOCK_PROTO_NO_REVENUE_PER_KPI, + expected_input_data=test_data.MOCK_INPUT_DATA_NO_REVENUE_PER_KPI, + ), + ) + def test_deserialize_marketing_data_proto( + self, marketing_data_proto, expected_input_data + ): + deserialized_data = self.serde.deserialize(marketing_data_proto) + xrt.assert_allclose( + deserialized_data.population, + expected_input_data.population, + atol=0.5, + rtol=0, + ) + self.assertEqual(deserialized_data.kpi_type, expected_input_data.kpi_type) + xrt.assert_allclose( + deserialized_data.kpi, + expected_input_data.kpi, + ) + if expected_input_data.revenue_per_kpi is None: + self.assertIsNone( + deserialized_data.revenue_per_kpi, + 'Expected revenue_per_kpi to be None', + ) + else: + 
xrt.assert_allclose( + deserialized_data.revenue_per_kpi, + expected_input_data.revenue_per_kpi, + ) + if expected_input_data.controls is None: + self.assertIsNone( + deserialized_data.controls, + 'Expected controls to be None', + ) + else: + xrt.assert_allclose( + deserialized_data.controls, + expected_input_data.controls, + ) + if expected_input_data.media is None: + self.assertIsNone(deserialized_data.media, 'Expected media to be None') + else: + xrt.assert_allclose( + deserialized_data.media, + expected_input_data.media, + ) + + if expected_input_data.media_spend is None: + self.assertIsNone( + deserialized_data.media, 'Expected media_spend to be None' + ) + else: + xrt.assert_allclose( + deserialized_data.media_spend, + expected_input_data.media_spend, + ) + + if expected_input_data.reach is None: + self.assertIsNone(deserialized_data.reach, 'Expected reach to be None') + else: + xrt.assert_allclose(deserialized_data.reach, expected_input_data.reach) + + if expected_input_data.frequency is None: + self.assertIsNone( + deserialized_data.frequency, 'Expected frequency to be None' + ) + else: + xrt.assert_allclose( + deserialized_data.frequency, expected_input_data.frequency + ) + + if expected_input_data.rf_spend is None: + self.assertIsNone( + deserialized_data.rf_spend, 'Expected rf_spend to be None' + ) + else: + xrt.assert_allclose( + deserialized_data.rf_spend, + expected_input_data.rf_spend, + ) + + if expected_input_data.organic_media is None: + self.assertIsNone( + deserialized_data.organic_media, 'Expected organic_media to be None' + ) + else: + xrt.assert_allclose( + deserialized_data.organic_media, expected_input_data.organic_media + ) + + if expected_input_data.organic_reach is None: + self.assertIsNone( + deserialized_data.organic_reach, 'Expected organic_reach to be None' + ) + else: + xrt.assert_allclose( + deserialized_data.organic_reach, expected_input_data.organic_reach + ) + + if expected_input_data.organic_frequency is None: + self.assertIsNone( + deserialized_data.organic_frequency, + 'Expected organic_frequency to be None', + ) + else: + xrt.assert_allclose( + deserialized_data.organic_frequency, + expected_input_data.organic_frequency, + ) + + if expected_input_data.non_media_treatments is None: + self.assertIsNone( + deserialized_data.non_media_treatments, + 'Expected non_media_treatments to be None', + ) + else: + xrt.assert_allclose( + deserialized_data.non_media_treatments, + expected_input_data.non_media_treatments, + ) + + @parameterized.named_parameters( + dict( + testcase_name='inconsistent_kpi_type', + marketing_data_proto=text_format.Parse( + """ + marketing_data_points { + geo_info { geo_id: "geo_0" } + date_interval { + start_date { year: 2023 month: 1 day: 1 } + end_date { year: 2023 month: 1 day: 8 } + } + kpi { revenue { value: 10 } } + } + marketing_data_points { + geo_info { geo_id: "geo_1" } + date_interval { + start_date { year: 2023 month: 1 day: 8 } + end_date { year: 2023 month: 1 day: 15 } + } + kpi { non_revenue { value: 5 } } + } + """, + marketing_pb.MarketingData(), + ), + expected_error_message='Inconsistent kpi_type found in the data.', + ), + dict( + testcase_name='missing_kpi_type', + marketing_data_proto=text_format.Parse( + """ + marketing_data_points { + geo_info { geo_id: "geo_0" } + date_interval { + start_date { year: 2023 month: 1 day: 1 } + end_date { year: 2023 month: 1 day: 8 } + } + } + """, + marketing_pb.MarketingData(), + ), + expected_error_message='kpi_type not found in the data.', + ), + ) + def 
test_extract_kpi_type_errors( + self, marketing_data_proto, expected_error_message + ): + with self.assertRaisesRegex(ValueError, expected_error_message): + self.serde.deserialize(marketing_data_proto) + + def test_extract_controls_missing_data_error(self): + marketing_data_proto = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + media_spend: 123.0 + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 40.0 + } + media_spend: 125.0 + } + kpi { + name: "revenue" + revenue { + value: 1.1 + } + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + } + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), + ) + + input_data = self.serde.deserialize(marketing_data_proto) + self.assertIsNone(input_data.controls) + + def test_deserialize_skip_non_rf_channel_in_extract_frequency(self): + marketing_data_proto = text_format.Parse( + """ + metadata { + time_dimensions { + name: "time" + dates { year: 2023 month: 1 day: 1 } + dates { year: 2023 month: 1 day: 8 } + } + time_dimensions { + name: "media_time" + dates { year: 2023 month: 1 day: 1 } + dates { year: 2023 month: 1 day: 8 } + } + channel_dimensions { + name: "media" + channels: "media_channel1" + } + channel_dimensions { + name: "reach" + channels: "rf_channel1" + } + channel_dimensions { + name: "frequency" + channels: "rf_channel1" + } + } + marketing_data_points { + geo_info { geo_id: "geo_0" } + date_interval { + start_date { year: 2023 month: 1 day: 1 } + end_date { year: 2023 month: 1 day: 8 } + } + kpi { non_revenue { value: 10 } } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "media_channel1" + } + reach_frequency_variables { + channel_name: "media_channel1" + reach: 1 + average_frequency: 2 + } + } + marketing_data_points { + geo_info { geo_id: "geo_0" } + date_interval { + start_date { year: 2023 month: 1 day: 8 } + end_date { year: 2023 month: 1 day: 15 } + } + kpi { non_revenue { value: 10 } } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "media_channel1" + } + reach_frequency_variables { + channel_name: "media_channel1" + reach: 1 + average_frequency: 2 + } + } + """, + marketing_pb.MarketingData(), + ) + deserialized_data = self.serde.deserialize(marketing_data_proto) + self.assertIsNone(deserialized_data.frequency) + + def test_deserialize_time_dimension_with_no_dates(self): + marketing_data_proto = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + 
control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + media_spend: 123.0 + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + metadata { + time_dimensions { + name: "time" + } + time_dimensions { + name: "media_time" + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + } + control_names: "control_0" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), + ) + with self.assertRaisesRegex( + ValueError, 'TimeDimension proto must have at least one date.' + ): + self.serde.deserialize(marketing_data_proto) + + def test_deserialize_aggregated_spend_incorrect_start_date_interval(self): + # Create a MarketingData proto with an incorrectly defined aggregated + # spend data point (no geo_info, but also wrong date interval). + marketing_data_proto = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 22 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 123.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + } + control_names: "control_0" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), + ) + + deserializer = marketing_data.MarketingDataSerde() + deserialized_input_data = deserializer.deserialize(marketing_data_proto) + + # This should be granular since the date interval doesn't match. + self.assertEqual( + deserialized_input_data.media_spend.sizes, + { + c.GEO: 1, + c.TIME: 1, + c.MEDIA_CHANNEL: 1, + }, + 'media_spend should have dimensions (geo=1, time=1, media_channel=1)' + ' when treated as granular', + ) + + def test_deserialize_aggregated_spend_incorrect_end_date_interval(self): + # Create a MarketingData proto with an incorrectly defined aggregated + # spend data point (no geo_info, but also wrong date interval). 
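+    # The spend-only data point below has no geo_info, but its end date
+    # (2021-02-16) does not line up with the media_time window in the
+    # metadata, so `_is_aggregated_spend_data_point` should reject it and the
+    # spend should fall back to the granular path.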
+ marketing_data_proto = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 16 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 123.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + } + control_names: "control_0" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), + ) + + deserializer = marketing_data.MarketingDataSerde() + deserialized_input_data = deserializer.deserialize(marketing_data_proto) + + # This should be granular since the date interval doesn't match. + self.assertEqual( + deserialized_input_data.media_spend.sizes, + { + c.GEO: 1, + c.TIME: 1, + c.MEDIA_CHANNEL: 1, + }, + 'media_spend should have dimensions (geo=1, time=1, media_channel=1)' + ' when treated as granular', + ) + + def test_deserialize_aggregated_spend_incorrect_geo_info(self): + # Create a MarketingData proto with an incorrectly defined aggregated + # spend data point (geo_info set, and correct date interval). 
+ marketing_data_proto = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 123.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + } + control_names: "control_0" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), + ) + + deserializer = marketing_data.MarketingDataSerde() + deserialized_input_data = deserializer.deserialize(marketing_data_proto) + + # This should be granular since the geo_info is set. + self.assertEqual( + deserialized_input_data.media_spend.sizes, + { + c.GEO: 1, + c.TIME: 1, + c.MEDIA_CHANNEL: 1, + }, + 'media_spend should have dimensions (geo=1, time=1, media_channel=1)' + ' when treated as granular', + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/schema/serde/meridian_serde.py b/schema/serde/meridian_serde.py new file mode 100644 index 000000000..6c05efb4b --- /dev/null +++ b/schema/serde/meridian_serde.py @@ -0,0 +1,350 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization and deserialization of Meridian models into/from proto format. + +The `meridian_serde.MeridianSerde` class provides an interface for serializing +and deserializing Meridian models into and from an `MmmKernel` proto message. + +The Meridian model--when serialized into an `MmmKernel` proto--is internally +represented as the sum of the following components: + +1. Marketing data: This includes the KPI, media, and control data present in + the input data. They are structured into an MMM-agnostic `MarketingData` + proto message. +2. 
Meridian model: A `MeridianModel` proto message encapsulates + Meridian-specific model parameters, including hyperparameters, prior + distributions, and sampled inference data. + +Sample usage: + +```python +from schema.serde import meridian_serde + +serde = meridian_serde.MeridianSerde() +mmm = model.Meridian(...) +serialized_mmm = serde.serialize(mmm) # An `MmmKernel` proto +deserialized_mmm = serde.deserialize(serialized_mmm) # A `Meridian` object +``` +""" + +import dataclasses +import os +from typing import Any, Callable + +from google.protobuf import text_format +import meridian +from meridian import backend +from meridian.analysis import analyzer +from meridian.analysis import visualizer +from meridian.model import model +from mmm.v1.model import mmm_kernel_pb2 as kernel_pb +from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb +from schema.serde import distribution +from schema.serde import hyperparameters +from schema.serde import inference_data +from schema.serde import marketing_data +from schema.serde import serde +import semver + +from google.protobuf import any_pb2 + + +_VERSION_INFO = semver.VersionInfo.parse(meridian.__version__) + +FunctionRegistry = dict[str, Callable[..., Any]] + +_file_exists = os.path.exists +_make_dirs = os.makedirs +_file_open = open + + +class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]): + """Serializes and deserializes a Meridian model into an `MmmKernel` proto.""" + + def serialize( + self, + obj: model.Meridian, + model_id: str = '', + meridian_version: semver.VersionInfo = _VERSION_INFO, + include_convergence_info: bool = False, + function_registry: FunctionRegistry | None = None, + ) -> kernel_pb.MmmKernel: + """Serializes the given Meridian model into an `MmmKernel` proto. + + Args: + obj: The Meridian model to serialize. + model_id: The ID of the model. + meridian_version: The version of the Meridian model. + include_convergence_info: Whether to include convergence information. + function_registry: Optional. A lookup table that maps string keys to + custom functions to be used as parameters in various + `tfp.distributions`. + + Returns: + An `MmmKernel` proto representing the serialized model. + """ + meridian_model_proto = self._make_meridian_model_proto( + obj, + model_id, + meridian_version, + include_convergence_info, + function_registry, + ) + any_model = any_pb2.Any() + any_model.Pack(meridian_model_proto) + return kernel_pb.MmmKernel( + marketing_data=marketing_data.MarketingDataSerde().serialize( + obj.input_data + ), + model=any_model, + ) + + def _make_meridian_model_proto( + self, + mmm: model.Meridian, + model_id: str, + meridian_version: semver.VersionInfo, + include_convergence_info: bool = False, + function_registry: FunctionRegistry | None = None, + ) -> meridian_pb.MeridianModel: + """Constructs a MeridianModel proto from the TrainedModel. + + Args: + mmm: Meridian model. + model_id: The ID of the model. + meridian_version: The version of the Meridian model. + include_convergence_info: Whether to include convergence information. + function_registry: Optional. A lookup table that maps string keys to + custom functions to be used as parameters in various + `tfp.distributions`. + + Returns: + A MeridianModel proto. 
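+
+    The result is typically packed into `MmmKernel.model` as a
+    `google.protobuf.Any` (see `serialize` above). A minimal unpacking
+    sketch, where `kernel` is a hypothetical `MmmKernel` instance:
+
+    ```python
+    unpacked = meridian_pb.MeridianModel()
+    if kernel.model.Is(meridian_pb.MeridianModel.DESCRIPTOR):
+      kernel.model.Unpack(unpacked)
+    ```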
+ """ + + model_proto = meridian_pb.MeridianModel( + model_id=model_id, + model_version=str(meridian_version), + hyperparameters=hyperparameters.HyperparametersSerde().serialize( + mmm.model_spec + ), + prior_tfp_distributions=distribution.DistributionSerde( + function_registry + ).serialize(mmm.model_spec.prior), + inference_data=inference_data.InferenceDataSerde().serialize( + mmm.inference_data + ), + kpi_scaled=backend.make_tensor_proto(mmm.kpi_scaled), + ) + + if mmm.controls_scaled is not None: + model_proto.controls_scaled.CopyFrom( + backend.make_tensor_proto(mmm.controls_scaled) + ) + + media_tensors = mmm.media_tensors + rf_tensors = mmm.rf_tensors + if media_tensors.media_scaled is not None: + model_proto.media_scaled.CopyFrom( + backend.make_tensor_proto(media_tensors.media_scaled) + ) + if rf_tensors.reach_scaled is not None: + model_proto.reach_scaled.CopyFrom( + backend.make_tensor_proto(rf_tensors.reach_scaled) + ) + + if include_convergence_info: + convergence_proto = self._make_model_convergence_proto(mmm) + if convergence_proto is not None: + model_proto.convergence_info.CopyFrom(convergence_proto) + + return model_proto + + def _make_model_convergence_proto( + self, mmm: model.Meridian + ) -> meridian_pb.ModelConvergence | None: + """Creates ModelConvergence proto.""" + model_convergence_proto = meridian_pb.ModelConvergence() + try: + # NotFittedModelError can be raised below. If raised, + # return None. Otherwise, set convergence status based on + # MCMCSamplingError (caught in the except block). + rhats = analyzer.Analyzer(mmm).get_rhat() + rhat_proto = meridian_pb.RHatDiagnostic() + for name, tensor in rhats.items(): + rhat_proto.parameter_r_hats.add( + name=name, tensor=backend.make_tensor_proto(tensor) + ) + model_convergence_proto.r_hat_diagnostic.CopyFrom(rhat_proto) + + visualizer.ModelDiagnostics(mmm).plot_rhat_boxplot() + model_convergence_proto.convergence = True + except model.MCMCSamplingError: + model_convergence_proto.convergence = False + except model.NotFittedModelError: + return None + + if hasattr(mmm.inference_data, 'trace'): + trace = mmm.inference_data.trace + mcmc_sampling_trace = meridian_pb.McmcSamplingTrace( + num_chains=len(trace.chain), + num_draws=len(trace.draw), + step_size=backend.make_tensor_proto(trace.step_size), + tune=backend.make_tensor_proto(trace.tune), + target_log_prob=backend.make_tensor_proto(trace.target_log_prob), + diverging=backend.make_tensor_proto(trace.diverging), + accept_ratio=backend.make_tensor_proto(trace.accept_ratio), + n_steps=backend.make_tensor_proto(trace.n_steps), + is_accepted=backend.make_tensor_proto(trace.is_accepted), + ) + model_convergence_proto.mcmc_sampling_trace.CopyFrom(mcmc_sampling_trace) + + return model_convergence_proto + + def deserialize( + self, + serialized: kernel_pb.MmmKernel, + serialized_version: str = '', + function_registry: FunctionRegistry | None = None, + force_deserialization=False, + ) -> model.Meridian: + """Deserializes the given `MmmKernel` proto into a Meridian model.""" + if serialized.model.Is(meridian_pb.MeridianModel.DESCRIPTOR): + ser_meridian = meridian_pb.MeridianModel() + else: + raise ValueError('`serialized.model` is not a `MeridianModel`.') + serialized.model.Unpack(ser_meridian) + serialized_version = semver.VersionInfo.parse(ser_meridian.model_version) + + deserialized_hyperparameters = ( + hyperparameters.HyperparametersSerde().deserialize( + ser_meridian.hyperparameters, str(serialized_version) + ) + ) + + if ser_meridian.HasField('prior_distributions'): + 
ser_meridian_priors = ser_meridian.prior_distributions
+    elif ser_meridian.HasField('prior_tfp_distributions'):
+      ser_meridian_priors = ser_meridian.prior_tfp_distributions
+    else:
+      raise ValueError('MeridianModel does not contain any priors.')
+
+    deserialized_prior_distributions = distribution.DistributionSerde(
+        function_registry
+    ).deserialize(
+        ser_meridian_priors,
+        str(serialized_version),
+        force_deserialization=force_deserialization,
+    )
+    deserialized_marketing_data = (
+        marketing_data.MarketingDataSerde().deserialize(
+            serialized.marketing_data, str(serialized_version)
+        )
+    )
+    deserialized_inference_data = (
+        inference_data.InferenceDataSerde().deserialize(
+            ser_meridian.inference_data, str(serialized_version)
+        )
+    )
+
+    deserialized_model_spec = dataclasses.replace(
+        deserialized_hyperparameters, prior=deserialized_prior_distributions
+    )
+
+    return model.Meridian(
+        input_data=deserialized_marketing_data,
+        model_spec=deserialized_model_spec,
+        inference_data=deserialized_inference_data,
+    )
+
+
+def save_meridian(
+    mmm: model.Meridian,
+    file_path: str,
+    function_registry: FunctionRegistry | None = None,
+):
+  """Saves the model object as an `MmmKernel` proto at the given file path.
+
+  Supported file types:
+  - `binpb` (wire-format proto)
+  - `txtpb` (text-format proto)
+  - `textproto` (text-format proto)
+
+  Args:
+    mmm: Model object to save.
+    file_path: File path to save a serialized model object. If the file name
+      ends with `.binpb`, it is saved in the wire format. If the file name
+      ends with `.txtpb` or `.textproto`, it is saved in the text format.
+    function_registry: Optional. A lookup table that maps string keys to custom
+      functions to be used as parameters in various `tfp.distributions`.
+  """
+  dir_name = os.path.dirname(file_path)
+  # `os.path.dirname` returns '' for a bare file name; only create parent
+  # directories when one is present.
+  if dir_name and not _file_exists(dir_name):
+    _make_dirs(dir_name)
+
+  # Creates an MmmKernel.
+  serialized_kernel = MeridianSerde().serialize(
+      mmm, function_registry=function_registry
+  )
+  if file_path.endswith('.binpb'):
+    with _file_open(file_path, 'wb') as f:
+      f.write(serialized_kernel.SerializeToString())
+  elif file_path.endswith('.textproto') or file_path.endswith('.txtpb'):
+    # `text_format.MessageToString` returns `str`, so the file is opened in
+    # text mode rather than binary mode.
+    with _file_open(file_path, 'w') as f:
+      f.write(text_format.MessageToString(serialized_kernel))
+  else:
+    raise ValueError(f'Unsupported file type: {file_path}')
+
+
+def load_meridian(
+    file_path: str,
+    function_registry: FunctionRegistry | None = None,
+    force_deserialization: bool = False,
+) -> model.Meridian:
+  """Loads the model object from an `MmmKernel` proto file path.
+
+  Supported file types:
+  - `binpb` (wire-format proto)
+  - `txtpb` (text-format proto)
+  - `textproto` (text-format proto)
+
+  Args:
+    file_path: File path to load a serialized model object from.
+    function_registry: A lookup table that maps string keys to custom functions
+      to be used as parameters in various `tfp.distributions`.
+    force_deserialization: If True, bypasses the safety check that validates
+      whether functions within `function_registry` have changed after
+      serialization. Use with caution. This should only be used if you have
+      intentionally modified a custom function and are confident that the
+      changes will not affect the deserialized model. A safer alternative is
+      to first deserialize the model with the original functions and then
+      serialize it with the new ones.
+
+  Returns:
+    Model object loaded from the file path.
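+
+  Sample usage, where `mmm` is a fitted `model.Meridian` and the path is
+  illustrative:
+
+  ```python
+  save_meridian(mmm, '/tmp/models/mmm.binpb')
+  restored = load_meridian('/tmp/models/mmm.binpb')
+  ```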
+ """ + with _file_open(file_path, 'rb') as f: + if file_path.endswith('.binpb'): + serialized_model = kernel_pb.MmmKernel.FromString(f.read()) + elif file_path.endswith('.textproto') or file_path.endswith('.txtpb'): + serialized_model = kernel_pb.MmmKernel() + text_format.Parse(f.read(), serialized_model) + else: + raise ValueError(f'Unsupported file type: {file_path}') + return MeridianSerde().deserialize( + serialized_model, + function_registry=function_registry, + force_deserialization=force_deserialization, + ) diff --git a/schema/serde/serde.py b/schema/serde/serde.py new file mode 100644 index 000000000..b4e32c865 --- /dev/null +++ b/schema/serde/serde.py @@ -0,0 +1,34 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization and deserialization of Meridian models.""" + +import abc +from typing import Generic, TypeVar + + +WireFormat = TypeVar("WireFormat") +PythonType = TypeVar("PythonType") + + +class Serde(Generic[WireFormat, PythonType], abc.ABC): + """Serializes and deserializes a Python type into a wire format.""" + + def serialize(self, obj: PythonType, **kwargs) -> WireFormat: + """Serializes the given object into a wire format.""" + raise NotImplementedError() + + def deserialize(self, serialized: WireFormat, **kwargs) -> PythonType: + """Deserializes the given wire format into a Python object.""" + raise NotImplementedError() diff --git a/schema/serde/test_data.py b/schema/serde/test_data.py new file mode 100644 index 000000000..91d885bd4 --- /dev/null +++ b/schema/serde/test_data.py @@ -0,0 +1,4589 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test data for serde module.""" + +import inspect +import types +from typing import Any, Sequence +from unittest import mock + +from meridian import constants as c +from meridian.model import prior_distribution +from meridian.model import spec +from mmm.v1.marketing import marketing_data_pb2 as marketing_pb +from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp +import xarray as xr + +from google.protobuf import text_format +from tensorflow.core.framework import tensor_pb2 # pylint: disable=g-direct-tensorflow-import +from tensorflow.core.framework import tensor_shape_pb2 # pylint: disable=g-direct-tensorflow-import +from tensorflow.core.framework import types_pb2 # pylint: disable=g-direct-tensorflow-import + +tfd = tfp.distributions +tfb = tfp.bijectors + +_MediaEffectsDist = meridian_pb.MediaEffectsDistribution +_PaidMediaPriorType = meridian_pb.PaidMediaPriorType +_NonPaidTreatmentsPriorType = meridian_pb.NonPaidTreatmentsPriorType + +# Shared constants +_TIME_STRS = ['2021-02-01', '2021-02-08'] +_MEDIA_TIME_STRS = ['2021-01-25', '2021-02-01', '2021-02-08'] +_GEO_IDS = ['geo_0', 'geo_1'] +_MEDIA_CHANNEL_PAID = ['ch_paid_0', 'ch_paid_1'] +_MEDIA_CHANNEL_ORGANIC = ['ch_organic_0', 'ch_organic_1'] +_RF_CHANNEL_PAID = ['rf_ch_paid_0', 'rf_ch_paid_1'] +_RF_CHANNEL_ORGANIC = ['rf_ch_organic_0', 'rf_ch_organic_1'] +_CONTROL_VARIABLES = ['control_0', 'control_1'] +_NON_MEDIA_TREATMENT_VARIABLES = [ + 'non_media_treatment_0', + 'non_media_treatment_1', +] + + +def make_tensor_shape_proto( + dims: Sequence[int], +) -> tensor_shape_pb2.TensorShapeProto: + tensor_shape = tensor_shape_pb2.TensorShapeProto() + for dim in dims: + tensor_shape.dim.append(tensor_shape_pb2.TensorShapeProto.Dim(size=dim)) + return tensor_shape + + +def make_tensor_proto( + dims: Sequence[int], + dtype: types_pb2.DataType = types_pb2.DT_FLOAT, + bool_vals: Sequence[bool] = (), + string_vals: Sequence[str] = (), + tensor_content: bytes = b'', +) -> tensor_pb2.TensorProto: + return tensor_pb2.TensorProto( + dtype=dtype, + tensor_shape=make_tensor_shape_proto(dims), + bool_val=bool_vals, + string_val=[x.encode() for x in string_vals], + tensor_content=tensor_content, + ) + + +def make_sample_dataset( + n_chains: int, + n_draws: int, + n_geos: int = 5, + n_controls: int = 2, + n_knots: int = 0, + n_times: int = 0, + n_media_channels: int = 0, + n_rf_channels: int = 0, + n_organic_media_channels: int = 0, + n_organic_rf_channels: int = 0, + n_non_media_channels: int = 0, +) -> xr.Dataset: + """Creates a sample dataset with all relevant Meridian dimensions. + + Args: + n_chains: The number of chains. + n_draws: The number of draws per chain. + n_geos: The number of geos. + n_controls: The number of control variables. + n_knots: The number of knots. + n_times: The number of time periods. + n_media_channels: The number of media channels. + n_rf_channels: The number of reach and frequency channels. + n_organic_media_channels: The number of organic media channels. + n_organic_rf_channels: The number of organic reach and frequency channels. + n_non_media_channels: The number of non-media channels. + + Returns: + An xarray Dataset with sample data. 
+ """ + data_vars = { + c.STEP_SIZE: ( + [c.CHAIN, c.DRAW], + np.random.normal(size=(n_chains, n_draws)), + ), + c.TUNE: ( + [c.CHAIN, c.DRAW], + np.full((n_chains, n_draws), False), + ), + c.TARGET_LOG_PROBABILITY_TF: ( + [c.CHAIN, c.DRAW], + np.random.normal(size=(n_chains, n_draws)), + ), + c.DIVERGING: ( + [c.CHAIN, c.DRAW], + np.full((n_chains, n_draws), False), + ), + c.ACCEPT_RATIO: ( + [c.CHAIN, c.DRAW], + np.random.normal(size=(n_chains, n_draws)), + ), + c.N_STEPS: ( + [c.CHAIN, c.DRAW], + np.random.normal(size=(n_chains, n_draws)), + ), + 'is_accepted': ( + [c.CHAIN, c.DRAW], + np.full((n_chains, n_draws), True), + ), + } + coords = { + c.CHAIN: ([c.CHAIN], np.arange(n_chains)), + c.DRAW: ([c.DRAW], np.arange(n_draws)), + c.GEO: ([c.GEO], np.arange(n_geos)), + c.CONTROL_VARIABLE: ( + [c.CONTROL_VARIABLE], + np.arange(n_controls), + ), + } + + if n_knots > 0: + coords[c.KNOTS] = ([c.KNOTS], np.arange(n_knots)) + + if n_times > 0: + coords[c.TIME] = ([c.TIME], np.arange(n_times)) + + if n_media_channels > 0: + coords[c.MEDIA_CHANNEL] = ( + [c.MEDIA_CHANNEL], + np.arange(n_media_channels), + ) + + if n_rf_channels > 0: + coords[c.RF_CHANNEL] = ( + [c.RF_CHANNEL], + np.arange(n_rf_channels), + ) + + if n_organic_media_channels > 0: + coords[c.ORGANIC_MEDIA_CHANNEL] = ( + [c.ORGANIC_MEDIA_CHANNEL], + np.arange(n_organic_media_channels), + ) + + if n_organic_rf_channels > 0: + coords[c.ORGANIC_RF_CHANNEL] = ( + [c.ORGANIC_RF_CHANNEL], + np.arange(n_organic_rf_channels), + ) + + if n_non_media_channels > 0: + coords[c.NON_MEDIA_CHANNEL] = ( + [c.NON_MEDIA_CHANNEL], + np.arange(n_non_media_channels), + ) + + return xr.Dataset(data_vars, coords=coords) + + +# Marketing data test data +MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE = mock.MagicMock( + kpi_type=c.NON_REVENUE, + geo=xr.DataArray(np.array(['national_geo'])), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: ['national_geo']}, + data=np.array([1.0]), + name=c.POPULATION, + ), + media=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[41, 42], [43, 44]]]), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[141, 142], [143, 144]]]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=True, + media_spend_has_time_dimension=True, + reach=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[51.0, 52.0], [53.0, 54.0]]]), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[1.1, 1.2], [2, 3]]]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={c.RF_CHANNEL: _RF_CHANNEL_PAID}, + data=np.array([502, 504]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=False, + rf_spend_has_time_dimension=False, + kpi=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + }, + data=np.array([[1, 2]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + }, + data=np.array([[11, 12]]), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: ['control_0', 'control_1'], + }, + 
data=np.array([[[31, 32], [33, 34]]]), + name=c.CONTROLS, + ), + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "national_geo" + population: 1 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + media_spend: 141.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + media_spend: 142.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 51 + average_frequency: 1.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 52 + average_frequency: 1.2 + } + kpi { + name: "non_revenue" + non_revenue { + value: 1.0 + revenue_per_kpi: 11.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "national_geo" + population: 1 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + media_spend: 143.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + media_spend: 144.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 53 + average_frequency: 2.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 54 + average_frequency: 3.0 + } + kpi { + name: "non_revenue" + non_revenue { + value: 2.0 + revenue_per_kpi: 12.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + spend: 502.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + spend: 504.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "non_revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Same as above, but with no controls. 
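+# With `controls=None` on the input, the expected proto below has no
+# `control_variables` entries in its data points and no `control_names`
+# in its `metadata`.
+#
+# Illustrative only: these MOCK_INPUT_DATA_* / MOCK_PROTO_* pairs are
+# intended for round-trip assertions along the lines of the sketch below,
+# run inside a unittest/absltest test case. The converter name
+# `to_marketing_data` is an assumption, not the actual API under test.
+#
+#   actual = converter.to_marketing_data(
+#       MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS)
+#   self.assertEqual(
+#       actual, MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS)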
+MOCK_INPUT_DATA_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS = mock.MagicMock( + kpi_type=c.NON_REVENUE, + geo=xr.DataArray(np.array(['national_geo'])), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: ['national_geo']}, + data=np.array([1.0]), + name=c.POPULATION, + ), + media=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[41, 42], [43, 44]]]), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[141, 142], [143, 144]]]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=True, + media_spend_has_time_dimension=True, + reach=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[51.0, 52.0], [53.0, 54.0]]]), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[1.1, 1.2], [2, 3]]]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={c.RF_CHANNEL: _RF_CHANNEL_PAID}, + data=np.array([502, 504]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=False, + rf_spend_has_time_dimension=False, + kpi=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + }, + data=np.array([[1, 2]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={ + c.GEO: ['national_geo'], + c.TIME: _TIME_STRS, + }, + data=np.array([[11, 12]]), + name=c.REVENUE_PER_KPI, + ), + controls=None, + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_NATIONAL_MEDIA_RF_NON_REVENUE_NO_CONTROLS = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "national_geo" + population: 1 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + media_spend: 141.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + media_spend: 142.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 51 + average_frequency: 1.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 52 + average_frequency: 1.2 + } + kpi { + name: "non_revenue" + non_revenue { + value: 1.0 + revenue_per_kpi: 11.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "national_geo" + population: 1 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + media_spend: 143.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + media_spend: 144.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 53 + average_frequency: 2.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 54 + average_frequency: 3.0 + } + kpi { + name: "non_revenue" + non_revenue { + value: 2.0 + revenue_per_kpi: 12.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + 
end_date { + year: 2021 + month: 2 + day: 15 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + spend: 502.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + spend: 504.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + kpi_type: "non_revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Media, Paid, Expanded, Lagged +MOCK_INPUT_DATA_MEDIA_PAID_EXPANDED_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, + data=np.array([11.1, 12.2]), + name=c.POPULATION, + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array( + [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]] + ), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID}, + data=np.array([492, 496]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=False, + media_spend_has_time_dimension=False, + reach=None, + frequency=None, + rf_spend=None, + kpi=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + }, + data=np.array([[2, 3], [4, 5]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + }, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[32, 33], [34, 35]], [[36, 37], [38, 39]]]), + name=c.CONTROLS, + ), + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_MEDIA_PAID_EXPANDED_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 32.0 + } + control_variables { + name: "control_1" + value: 33.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 
2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 34.0 + } + control_variables { + name: "control_1" + value: 35.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 36.0 + } + control_variables { + name: "control_1" + value: 37.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 47.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 48.0 + } + } + kpi { + name: "revenue" + revenue { + value: 4.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 38.0 + } + control_variables { + name: "control_1" + value: 39.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 49.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 50.0 + } + } + kpi { + name: "revenue" + revenue { + value: 5.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 492.0 + } + media_variables { + channel_name: "ch_paid_1" + media_spend: 496.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 1 + day: 25 + } + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Media, Paid, Granular, Not Lagged +MOCK_INPUT_DATA_MEDIA_PAID_GRANULAR_NOT_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]), + name=c.MEDIA, + ), + 
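+    # "Granular" spend: media_spend carries geo and time dimensions, so the
+    # expected proto embeds media_spend in each per-geo, per-period data
+    # point rather than in a trailing aggregate spend data point.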
media_spend=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[123, 124], [125, 126]], [[127, 128], [129, 130]]]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=True, + media_spend_has_time_dimension=True, + reach=None, + frequency=None, + rf_spend=None, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[1, 2], [3, 4]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED_STRING = """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + media_spend: 123.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + media_spend: 124.0 + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } +} +marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + media_spend: 125.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + media_spend: 126.0 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } +} +marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + media_spend: 127.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + media_spend: 128.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } +} +marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + media_spend: 129.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + media_spend: 130.0 + } + kpi { + name: "revenue" + revenue { + value: 4.0 + } + } 
+} +metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" +} +""" + +MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED = text_format.Parse( + MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED_STRING, + marketing_pb.MarketingData(), +) + +# Media, Organic, Expanded, Lagged +MOCK_INPUT_DATA_MEDIA_ORGANIC_EXPANDED_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID}, + data=np.array([492, 496]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=False, + media_spend_has_time_dimension=False, + reach=None, + frequency=None, + rf_spend=None, + organic_media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.ORGANIC_MEDIA_CHANNEL: _MEDIA_CHANNEL_ORGANIC, + }, + data=np.array( + [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]] + ), + name=c.ORGANIC_MEDIA, + ), + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[2, 2], [3, 3]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), +) + +MOCK_PROTO_MEDIA_ORGANIC_EXPANDED_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 1.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 2.0 + } + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 3.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 4.0 + } + } 
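+      # Organic media channels are serialized as media_variables as well;
+      # they simply carry no media_spend.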
+ media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 5.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 6.0 + } + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 7.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 8.0 + } + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 9.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 10.0 + } + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 47.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 48.0 + } + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 11.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 12.0 + } + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 49.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 50.0 + } + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 
+ month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 492.0 + } + media_variables { + channel_name: "ch_paid_1" + media_spend: 496.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 1 + day: 25 + } + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + channel_dimensions { + name: "organic_media" + channels: "ch_organic_0" + channels: "ch_organic_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Media, Organic, Granular, Not Lagged +MOCK_INPUT_DATA_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[123, 124], [125, 126]], [[127, 128], [129, 130]]]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=True, + media_spend_has_time_dimension=True, + reach=None, + frequency=None, + rf_spend=None, + organic_media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.ORGANIC_MEDIA_CHANNEL: _MEDIA_CHANNEL_ORGANIC, + }, + data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]), + name=c.ORGANIC_MEDIA, + ), + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[2, 2], [3, 3]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), +) + +MOCK_PROTO_MEDIA_ORGANIC_GRANULAR_NOT_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 1.0 + } + media_spend: 123.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 2.0 + } + media_spend: 124.0 + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + 
population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 3.0 + } + media_spend: 125.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 4.0 + } + media_spend: 126.0 + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 5.0 + } + media_spend: 127.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 6.0 + } + media_spend: 128.0 + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 7.0 + } + media_spend: 129.0 + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 8.0 + } + media_spend: 130.0 + } + media_variables { + channel_name: "ch_organic_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + } + media_variables { + channel_name: "ch_organic_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + channel_dimensions { + name: "organic_media" + channels: "ch_organic_0" + channels: "ch_organic_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Reach and Frequency, Paid, Expanded, Lagged +MOCK_INPUT_DATA_RF_PAID_EXPANDED_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, + data=np.array([11.1, 12.2]), + name=c.POPULATION, + ), + 
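+    # RF-only fixture: paid media impressions and spend are absent.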
media=None, + media_spend=None, + reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array( + [[[51, 52], [53, 54], [55, 56]], [[57, 58], [59, 60], [61, 62]]] + ), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([ + [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]], + [[1.7, 1.8], [1.9, 2.0], [2.1, 2.2]], + ]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={c.RF_CHANNEL: _RF_CHANNEL_PAID}, + data=np.array([1004, 1008]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=False, + rf_spend_has_time_dimension=False, + kpi=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + }, + data=np.array([[2, 3], [4, 5]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + }, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[32, 33], [34, 35]], [[36, 37], [38, 39]]]), + name=c.CONTROLS, + ), + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_RF_PAID_EXPANDED_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 51 + average_frequency: 1.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 52 + average_frequency: 1.2 + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 32.0 + } + control_variables { + name: "control_1" + value: 33.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 53 + average_frequency: 1.3 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 54 + average_frequency: 1.4 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 34.0 + } + control_variables { + name: "control_1" + value: 35.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 55 + average_frequency: 1.5 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 56 + average_frequency: 1.6 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 57 + average_frequency: 1.7 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 58 + average_frequency: 1.8 + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + 
month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 36.0 + } + control_variables { + name: "control_1" + value: 37.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 59 + average_frequency: 1.9 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 60 + average_frequency: 2.0 + } + kpi { + name: "revenue" + revenue { + value: 4.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 38.0 + } + control_variables { + name: "control_1" + value: 39.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 61 + average_frequency: 2.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 62 + average_frequency: 2.2 + } + kpi { + name: "revenue" + revenue { + value: 5.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + spend: 1004.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + spend: 1008.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 1 + day: 25 + } + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Reach and Frequency, Paid, Granular, Not Lagged +MOCK_INPUT_DATA_RF_PAID_GRANULAR_NOT_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=None, + media_spend=None, + reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array( + [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]] + ), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[1.1, 1.2], [2, 3]], [[4, 5], [6, 7]]]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[252, 253], [254, 255]], [[256, 257], [258, 259]]]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=True, + rf_spend_has_time_dimension=True, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[1, 2], [3, 4]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + 
name=c.CONTROLS, + ), + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +MOCK_PROTO_RF_PAID_GRANULAR_NOT_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 51 + average_frequency: 1.1 + spend: 252.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 52 + average_frequency: 1.2 + spend: 253.0 + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 53 + average_frequency: 2.0 + spend: 254.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 54 + average_frequency: 3.0 + spend: 255.0 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 55 + average_frequency: 4.0 + spend: 256.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 56 + average_frequency: 5.0 + spend: 257.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 57 + average_frequency: 6.0 + spend: 258.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 58 + average_frequency: 7.0 + spend: 259.0 + } + kpi { + name: "revenue" + revenue { + value: 4.0 + } + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Reach and Frequency, Organic, Expanded, Lagged +MOCK_INPUT_DATA_RF_ORGANIC_EXPANDED_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)), 
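+    # "Expanded, Lagged": media_time has one extra leading period relative
+    # to time, so the earliest data point per geo carries reach/frequency
+    # values but no KPI or controls.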
+ population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=None, + media_spend=None, + reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([ + [[2.1, 2.2], [2.3, 2.4], [2.5, 2.6]], + [[2.7, 2.8], [2.9, 3.0], [3.1, 3.2]], + ]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={c.RF_CHANNEL: _RF_CHANNEL_PAID}, + data=np.array([1004, 1008]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=False, + rf_spend_has_time_dimension=False, + organic_media=None, + organic_reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC, + }, + data=np.array( + [[[51, 52], [53, 54], [55, 56]], [[57, 58], [59, 60], [61, 62]]] + ), + name=c.ORGANIC_REACH, + ), + organic_frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC, + }, + data=np.array([ + [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]], + [[1.7, 1.8], [1.9, 2.0], [2.1, 2.2]], + ]), + name=c.ORGANIC_FREQUENCY, + ), + non_media_treatments=None, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[2, 2], [3, 3]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), +) + +MOCK_PROTO_RF_ORGANIC_EXPANDED_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 1 + average_frequency: 2.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 2 + average_frequency: 2.2 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 51 + average_frequency: 1.1 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 52 + average_frequency: 1.2 + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 3 + average_frequency: 2.3 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 4 + average_frequency: 2.4 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 53 + average_frequency: 1.3 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 54 + average_frequency: 1.4 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + 
day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 5 + average_frequency: 2.5 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 6 + average_frequency: 2.6 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 55 + average_frequency: 1.5 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 56 + average_frequency: 1.6 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 7 + average_frequency: 2.7 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 8 + average_frequency: 2.8 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 57 + average_frequency: 1.7 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 58 + average_frequency: 1.8 + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 9 + average_frequency: 2.9 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 10 + average_frequency: 3.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 59 + average_frequency: 1.9 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 60 + average_frequency: 2.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 11 + average_frequency: 3.1 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 12 + average_frequency: 3.2 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 61 + average_frequency: 2.1 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 62 + average_frequency: 2.2 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + spend: 1004.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + spend: 1008.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 1 + day: 25 + } + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + 
channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "organic_reach" + channels: "rf_ch_organic_0" + channels: "rf_ch_organic_1" + } + channel_dimensions { + name: "organic_frequency" + channels: "rf_ch_organic_0" + channels: "rf_ch_organic_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +# Reach and Frequency, Organic, Granular, Not Lagged +MOCK_INPUT_DATA_RF_ORGANIC_GRANULAR_NOT_LAGGED = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + media=None, + media_spend=None, + reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), + name=c.REACH, + ), + frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[2.1, 2.2], [3, 4]], [[5, 6], [7, 8]]]), + name=c.FREQUENCY, + ), + rf_spend=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.RF_CHANNEL: _RF_CHANNEL_PAID, + }, + data=np.array([[[252, 253], [254, 255]], [[256, 257], [258, 259]]]), + name=c.RF_SPEND, + ), + rf_spend_has_geo_dimension=True, + rf_spend_has_time_dimension=True, + organic_media=None, + organic_reach=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC, + }, + data=np.array( + [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]] + ), + name=c.ORGANIC_REACH, + ), + organic_frequency=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.ORGANIC_RF_CHANNEL: _RF_CHANNEL_ORGANIC, + }, + data=np.array( + [[[51.0, 52.0], [53.0, 54.0]], [[55.0, 56.0], [57.0, 58.0]]] + ), + name=c.ORGANIC_FREQUENCY, + ), + non_media_treatments=None, + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[2, 2], [3, 3]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), +) + +MOCK_PROTO_RF_ORGANIC_GRANULAR_NOT_LAGGED = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 1 + average_frequency: 2.1 + spend: 252.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 2 + average_frequency: 2.2 + spend: 253.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 51 + average_frequency: 51.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 52 + average_frequency: 52.0 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + 
marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 3 + average_frequency: 3.0 + spend: 254.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 4 + average_frequency: 4.0 + spend: 255.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 53 + average_frequency: 53.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 54 + average_frequency: 54.0 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 5 + average_frequency: 5.0 + spend: 256.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 6 + average_frequency: 6.0 + spend: 257.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 55 + average_frequency: 55.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 56 + average_frequency: 56.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_0" + reach: 7 + average_frequency: 7.0 + spend: 258.0 + } + reach_frequency_variables { + channel_name: "rf_ch_paid_1" + reach: 8 + average_frequency: 8.0 + spend: 259.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_0" + reach: 57 + average_frequency: 57.0 + } + reach_frequency_variables { + channel_name: "rf_ch_organic_1" + reach: 58 + average_frequency: 58.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "reach" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "frequency" + channels: "rf_ch_paid_0" + channels: "rf_ch_paid_1" + } + channel_dimensions { + name: "organic_reach" + channels: "rf_ch_organic_0" + channels: "rf_ch_organic_1" + } + channel_dimensions { + name: "organic_frequency" + channels: "rf_ch_organic_0" + channels: "rf_ch_organic_1" + } + control_names: "control_0" + control_names: "control_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +MOCK_INPUT_DATA_NON_MEDIA_TREATMENTS = mock.MagicMock( + kpi_type=c.REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_MEDIA_TIME_STRS)), + 
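+    # Adds non_media_treatments: the expected proto gains
+    # non_media_treatment_variables in each KPI-bearing data point and
+    # non_media_treatment_names in metadata.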
population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, data=np.array([11.1, 12.2]), name=c.POPULATION + ), + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[1, 2], [3, 4]]), + name=c.KPI, + ), + revenue_per_kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.ones((2, 2)), + name=c.REVENUE_PER_KPI, + ), + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + name=c.CONTROLS, + ), + non_media_treatments=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.NON_MEDIA_CHANNEL: _NON_MEDIA_TREATMENT_VARIABLES, + }, + data=np.array([[[61, 62], [63, 64]], [[65, 66], [67, 68]]]), + name=c.NON_MEDIA_TREATMENTS, + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _MEDIA_TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array( + [[[39, 40], [41, 42], [43, 44]], [[45, 46], [47, 48], [49, 50]]] + ), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID}, + data=np.array([492, 496]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=False, + media_spend_has_time_dimension=False, + reach=None, + frequency=None, + rf_spend=None, + rf_spend_has_geo_dimension=False, + rf_spend_has_time_dimension=False, + organic_media=None, + organic_reach=None, + organic_frequency=None, +) + +MOCK_PROTO_NON_MEDIA_TREATMENTS = text_format.Parse( + """ + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 31.0 + } + control_variables { + name: "control_1" + value: 32.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + } + non_media_treatment_variables { + name: "non_media_treatment_0" + value: 61.0 + } + non_media_treatment_variables { + name: "non_media_treatment_1" + value: 62.0 + } + kpi { + name: "revenue" + revenue { + value: 1.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_0" + population: 11 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 33.0 + } + control_variables { + name: "control_1" + value: 34.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + } + non_media_treatment_variables { + name: "non_media_treatment_0" + value: 63.0 + } + non_media_treatment_variables { + name: "non_media_treatment_1" + value: 64.0 + } + kpi { + name: "revenue" + revenue { + value: 2.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: 
"geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 1 + } + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 8 + } + } + control_variables { + name: "control_0" + value: 35.0 + } + control_variables { + name: "control_1" + value: 36.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 47.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 48.0 + } + } + non_media_treatment_variables { + name: "non_media_treatment_0" + value: 65.0 + } + non_media_treatment_variables { + name: "non_media_treatment_1" + value: 66.0 + } + kpi { + name: "revenue" + revenue { + value: 3.0 + } + } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 12 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 49.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 50.0 + } + } + non_media_treatment_variables { + name: "non_media_treatment_0" + value: 67.0 + } + non_media_treatment_variables { + name: "non_media_treatment_1" + value: 68.0 + } + kpi { + name: "revenue" + revenue { + value: 4.0 + } + } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 1 + day: 25 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 492.0 + } + media_variables { + channel_name: "ch_paid_1" + media_spend: 496.0 + } + } + metadata { + time_dimensions { + name: "time" + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + time_dimensions { + name: "media_time" + dates { + year: 2021 + month: 1 + day: 25 + } + dates { + year: 2021 + month: 2 + day: 1 + } + dates { + year: 2021 + month: 2 + day: 8 + } + } + channel_dimensions { + name: "media" + channels: "ch_paid_0" + channels: "ch_paid_1" + } + control_names: "control_0" + control_names: "control_1" + non_media_treatment_names: "non_media_treatment_0" + non_media_treatment_names: "non_media_treatment_1" + kpi_type: "revenue" + } + """, + marketing_pb.MarketingData(), +) + +MOCK_INPUT_DATA_NO_REVENUE_PER_KPI = mock.MagicMock( + kpi_type=c.NON_REVENUE, + geo=xr.DataArray(np.array(_GEO_IDS)), + time=xr.DataArray(np.array(_TIME_STRS)), + media_time=xr.DataArray(np.array(_TIME_STRS)), + population=xr.DataArray( + coords={c.GEO: _GEO_IDS}, + data=np.array([1000.0, 1200.0]), + name=c.POPULATION, + ), + kpi=xr.DataArray( + coords={c.GEO: _GEO_IDS, c.TIME: _TIME_STRS}, + data=np.array([[50, 60], [70, 80]]), + name=c.KPI, + ), + revenue_per_kpi=None, + controls=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.TIME: _TIME_STRS, + c.CONTROL_VARIABLE: _CONTROL_VARIABLES, + }, + data=np.array([[[31, 32], [33, 34]], [[35, 36], [37, 38]]]), + 
name=c.CONTROLS, + ), + media=xr.DataArray( + coords={ + c.GEO: _GEO_IDS, + c.MEDIA_TIME: _TIME_STRS, + c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID, + }, + data=np.array([[[39, 40], [41, 42]], [[43, 44], [45, 46]]]), + name=c.MEDIA, + ), + media_spend=xr.DataArray( + coords={c.MEDIA_CHANNEL: _MEDIA_CHANNEL_PAID}, + data=np.array([492, 496]), + name=c.MEDIA_SPEND, + ), + media_spend_has_geo_dimension=False, + media_spend_has_time_dimension=False, + reach=None, + frequency=None, + rf_spend=None, + organic_media=None, + organic_reach=None, + organic_frequency=None, + non_media_treatments=None, +) + +# Expected Protobuf (Textproto format) +MOCK_PROTO_NO_REVENUE_PER_KPI = text_format.Parse( + """ + marketing_data_points { + geo_info { geo_id: "geo_0" population: 1000 } + date_interval { + start_date { year: 2021 month: 2 day: 1 } + end_date { year: 2021 month: 2 day: 8 } + } + control_variables { name: "control_0" value: 31.0 } + control_variables { name: "control_1" value: 32.0 } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 39.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 40.0 + } + } + kpi { name: "non_revenue" non_revenue { value: 50.0 } } + } + marketing_data_points { + geo_info { geo_id: "geo_0" population: 1000 } + date_interval { + start_date { year: 2021 month: 2 day: 8 } + end_date { year: 2021 month: 2 day: 15 } + } + control_variables { name: "control_0" value: 33.0 } + control_variables { name: "control_1" value: 34.0 } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 41.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 42.0 + } + } + kpi { name: "non_revenue" non_revenue { value: 60.0 } } + } + marketing_data_points { + geo_info { geo_id: "geo_1" population: 1200 } + date_interval { + start_date { year: 2021 month: 2 day: 1 } + end_date { year: 2021 month: 2 day: 8 } + } + control_variables { name: "control_0" value: 35.0 } + control_variables { name: "control_1" value: 36.0 } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 43.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 44.0 + } + } + kpi { name: "non_revenue" non_revenue { value: 70.0 } } + } + marketing_data_points { + geo_info { + geo_id: "geo_1" + population: 1200 + } + date_interval { + start_date { + year: 2021 + month: 2 + day: 8 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + control_variables { + name: "control_0" + value: 37.0 + } + control_variables { + name: "control_1" + value: 38.0 + } + media_variables { + channel_name: "ch_paid_0" + scalar_metric { + name: "impressions" + value: 45.0 + } + } + media_variables { + channel_name: "ch_paid_1" + scalar_metric { + name: "impressions" + value: 46.0 + } + } + kpi { name: "non_revenue" non_revenue { value: 80.0 } } + } + marketing_data_points { + date_interval { + start_date { + year: 2021 + month: 2 + day: 1 + } + end_date { + year: 2021 + month: 2 + day: 15 + } + } + media_variables { + channel_name: "ch_paid_0" + media_spend: 492.0 + } + media_variables { + channel_name: "ch_paid_1" + media_spend: 496.0 + } + } + metadata { + time_dimensions { name: "time" dates { year: 2021 month: 2 day: 1 } dates { year: 2021 month: 2 day: 8} } + time_dimensions { name: "media_time" dates { year: 2021 month: 2 day: 1 } dates { year: 2021 month: 2 day: 8 } } 
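+ # `media_time` coincides with `time` in this mock because the input above
+ # provides no leading media-only periods; contrast
+ # MOCK_PROTO_NON_MEDIA_TREATMENTS, whose media_time starts one interval
+ # earlier (2021-01-25) than time.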
+ channel_dimensions { name: "media" channels: "ch_paid_0" channels: "ch_paid_1" } + control_names: "control_0" + control_names: "control_1" + kpi_type: "non_revenue" + } + """, + marketing_pb.MarketingData(), +) + + +# Hyperparameters test data +DEFAULT_MODEL_SPEC = spec.ModelSpec() + +DEFAULT_HYPERPARAMETERS_PROTO = meridian_pb.Hyperparameters( + media_effects_dist=_MediaEffectsDist.LOG_NORMAL, + hill_before_adstock=False, + max_lag=8, + unique_sigma_for_each_geo=False, + media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + rf_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + paid_media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + organic_media_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + non_media_treatments_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + enable_aks=False, + global_adstock_decay='geometric', +) + +CUSTOM_MODEL_SPEC_1 = spec.ModelSpec( + prior=prior_distribution.PriorDistribution(), + media_effects_dist=c.MEDIA_EFFECTS_NORMAL, + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=c.TREATMENT_PRIOR_TYPE_MROI, + rf_prior_type=c.TREATMENT_PRIOR_TYPE_MROI, + knots=2, + baseline_geo='baseline_geo', + roi_calibration_period=None, + rf_roi_calibration_period=None, + holdout_id=None, + control_population_scaling_id=None, + adstock_decay_spec='binomial', +) + +CUSTOM_HYPERPARAMETERS_PROTO_1 = meridian_pb.Hyperparameters( + media_effects_dist=_MediaEffectsDist.NORMAL, + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=_PaidMediaPriorType.MROI, + rf_prior_type=_PaidMediaPriorType.MROI, + paid_media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + organic_media_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + non_media_treatments_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + knots=[2], + baseline_geo_string='baseline_geo', + enable_aks=False, + global_adstock_decay='binomial', +) + + +CUSTOM_MODEL_SPEC_2 = spec.ModelSpec( + prior=prior_distribution.PriorDistribution(), + media_effects_dist='log_normal', + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=c.TREATMENT_PRIOR_TYPE_ROI, + rf_prior_type=c.TREATMENT_PRIOR_TYPE_ROI, + organic_media_prior_type=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=c.TREATMENT_PRIOR_TYPE_COEFFICIENT, + non_media_treatments_prior_type=c.TREATMENT_PRIOR_TYPE_COEFFICIENT, + knots=[1, 5, 8], + baseline_geo=3, + roi_calibration_period=np.full((2, 3), True), + rf_roi_calibration_period=np.full((4, 5), False), + holdout_id=np.full((6,), True), + control_population_scaling_id=np.full((7, 8), False), + non_media_population_scaling_id=np.full((9, 10), False), + adstock_decay_spec={'ch_paid_0': 'binomial', 'rf_ch_paid_1': 'geometric'}, +) + +CUSTOM_HYPERPARAMETERS_PROTO_2 = meridian_pb.Hyperparameters( + media_effects_dist=_MediaEffectsDist.LOG_NORMAL, + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=_PaidMediaPriorType.ROI, + rf_prior_type=_PaidMediaPriorType.ROI, + paid_media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + 
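+ # As in every fixture in this file, the combined `paid_media_prior_type` is
+ # left UNSPECIFIED whenever the split `media_prior_type`/`rf_prior_type`
+ # fields carry the value (an inference from the pattern across these
+ # fixtures, not from the proto documentation).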
organic_media_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT, + non_media_treatments_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_COEFFICIENT, + knots=[1, 5, 8], + baseline_geo_int=3, + roi_calibration_period=make_tensor_proto( + dims=[2, 3], + dtype=types_pb2.DT_BOOL, + bool_vals=[True] * (2 * 3), + ), + rf_roi_calibration_period=make_tensor_proto( + dims=[4, 5], + dtype=types_pb2.DT_BOOL, + bool_vals=[False] * (4 * 5), + ), + holdout_id=make_tensor_proto( + dims=[6], + dtype=types_pb2.DT_BOOL, + bool_vals=[True] * 6, + ), + control_population_scaling_id=make_tensor_proto( + dims=[7, 8], + dtype=types_pb2.DT_BOOL, + bool_vals=[False] * (7 * 8), + ), + non_media_population_scaling_id=make_tensor_proto( + dims=[9, 10], + dtype=types_pb2.DT_BOOL, + bool_vals=[False] * (9 * 10), + ), + enable_aks=False, + adstock_decay_by_channel=meridian_pb.AdstockDecayByChannel( + channel_decays={'ch_paid_0': 'binomial', 'rf_ch_paid_1': 'geometric'} + ), +) + +CUSTOM_MODEL_SPEC_3 = spec.ModelSpec( + prior=prior_distribution.PriorDistribution(), + media_effects_dist=c.MEDIA_EFFECTS_NORMAL, + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=c.TREATMENT_PRIOR_TYPE_MROI, + rf_prior_type=c.TREATMENT_PRIOR_TYPE_MROI, + organic_media_prior_type=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, + non_media_treatments_prior_type=c.TREATMENT_PRIOR_TYPE_CONTRIBUTION, + baseline_geo='baseline_geo', + roi_calibration_period=None, + rf_roi_calibration_period=None, + holdout_id=None, + control_population_scaling_id=None, + enable_aks=True, +) + +CUSTOM_HYPERPARAMETERS_PROTO_3 = meridian_pb.Hyperparameters( + media_effects_dist=_MediaEffectsDist.NORMAL, + hill_before_adstock=True, + max_lag=777, + unique_sigma_for_each_geo=True, + media_prior_type=_PaidMediaPriorType.MROI, + rf_prior_type=_PaidMediaPriorType.MROI, + paid_media_prior_type=_PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED, + organic_media_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + organic_rf_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + non_media_treatments_prior_type=_NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION, + baseline_geo_string='baseline_geo', + enable_aks=True, + global_adstock_decay='geometric', +) + + +def _create_tfp_params_from_dict( + param_dict: dict[str, Any], + distribution: tfd.Distribution | tfb.Bijector, +) -> dict[str, meridian_pb.TfpParameterValue]: + param_dict.update({ + 'validate_args': False, + }) + return { + key: _create_tfp_param(key, value, distribution) + for key, value in param_dict.items() + } + + +def create_distribution_proto( + distribution_type: str, **kwargs +) -> meridian_pb.TfpDistribution: + distribution = getattr(tfd, distribution_type) + return meridian_pb.TfpDistribution( + distribution_type=distribution_type, + parameters=_create_tfp_params_from_dict(kwargs, distribution), + ) + + +def create_bijector_proto( + bijector_type: str, **kwargs +) -> meridian_pb.TfpBijector: + bijector = getattr(tfb, bijector_type) + return meridian_pb.TfpBijector( + bijector_type=bijector_type, + parameters=_create_tfp_params_from_dict(kwargs, bijector), + ) + + +def _create_tfp_param(param_name, param_value, distribution): + """Creates a TfpParameterValue object based on the 
input value's type.""" + match param_value: + # `bool` must be matched before `int`: `bool` is a subclass of `int`, so + # `case int()` would otherwise capture True/False as `int_value`. + case bool(): + return meridian_pb.TfpParameterValue(bool_value=param_value) + case float(): + return meridian_pb.TfpParameterValue(scalar_value=param_value) + case int(): + return meridian_pb.TfpParameterValue(int_value=param_value) + case str(): + return meridian_pb.TfpParameterValue(string_value=param_value) + case None: + return meridian_pb.TfpParameterValue(none_value=True) + case list(): + value_generator = ( + _create_tfp_param(param_name, v, distribution) for v in param_value + ) + tfp_list_value = meridian_pb.TfpParameterValue.List( + values=value_generator + ) + return meridian_pb.TfpParameterValue(list_value=tfp_list_value) + case dict(): + dict_value = { + key: _create_tfp_param(key, v, distribution) + for key, v in param_value.items() + } + return meridian_pb.TfpParameterValue(dict_value=dict_value) + case tf.Tensor(): + tensor_value = tf.make_tensor_proto(param_value) + return meridian_pb.TfpParameterValue(tensor_value=tensor_value) + case meridian_pb.TfpDistribution(): + return meridian_pb.TfpParameterValue(distribution_value=param_value) + case meridian_pb.TfpBijector(): + return meridian_pb.TfpParameterValue(bijector_value=param_value) + case tfd.ReparameterizationType(): + fully_reparameterized = param_value == tfd.FULLY_REPARAMETERIZED + return meridian_pb.TfpParameterValue( + fully_reparameterized=fully_reparameterized + ) + case types.FunctionType(): + # Registry of custom functions used in tests. + test_registry = {'distribution_fn': distribution_fn} + + for function_key, func in test_registry.items(): + if func == param_value: # pylint: disable=comparison-with-callable + return meridian_pb.TfpParameterValue( + function_param=meridian_pb.TfpParameterValue.FunctionParam( + function_key=function_key + ) + ) + # Otherwise, check whether the function is the parameter's declared + # default. A missing default is `inspect.Parameter.empty`, which is + # truthy, so an explicit identity check is required here. + signature = inspect.signature(distribution.__init__) + param = signature.parameters[param_name] + if param.default is not inspect.Parameter.empty: + return meridian_pb.TfpParameterValue( + function_param=meridian_pb.TfpParameterValue.FunctionParam( + uses_default=True + ) + ) + raise TypeError( + f'No function found in registry for "{param_value.__name__}"' + ) + case _: + # Handle unsupported types. + raise TypeError(f'Unsupported type: {type(param_value)}') + + +# Arbitrary function used for testing `tfd.Autoregressive`: it one-hot encodes
+# `sample0` into three classes, shifts it by one frame, and pins the first
+# frame to a fixed starting distribution.
+# https://github.com/tensorflow/probability/blob/65f265c62bb1e2d15ef3e25104afb245a6d52429/tensorflow_probability/python/distributions/autoregressive_test.py#L89 +def distribution_fn(sample0): + num_frames = sample0.shape[-1] + mask = tf.one_hot(0, num_frames)[:, tf.newaxis] + probs = tf.roll(tf.one_hot(sample0, 3), shift=1, axis=-2) + probs = probs * (1.0 - mask) + tf.convert_to_tensor([0.5, 0.5, 0]) * mask + return tfd.Independent( + tfd.Categorical(probs=probs), reinterpreted_batch_ndims=1 + ) + + +def get_default_kwargs_split_fn(): + """Returns the default `kwargs_split_fn` used for tfd Distributions.""" + # `dist` can be any Distribution that has kwargs_split_fn in its signature.
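+ # (An assumption about current TFP internals, for context: at the time of
+ # writing, `tfd.TransformedDistribution.__init__` declares a default for
+ # `kwargs_split_fn` that splits a kwargs dict into separate distribution and
+ # bijector kwargs. We look that default up via `inspect` rather than
+ # hard-coding a private TFP symbol.)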
+ dist = tfd.TransformedDistribution + signature = inspect.signature(dist.__init__) + kwargs_split_fn_param = signature.parameters['kwargs_split_fn'] + return kwargs_split_fn_param.default diff --git a/schema/serde/testdata/autoregressive.textproto b/schema/serde/testdata/autoregressive.textproto new file mode 100644 index 000000000..8078be16f --- /dev/null +++ b/schema/serde/testdata/autoregressive.textproto @@ -0,0 +1,50 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: TfpDistribution + +distribution_type: "Autoregressive" +parameters { + key: "distribution_fn" + value { + function_param { + function_key: "distribution_fn" + } + } +} +parameters { + key: "sample0" + value { + tensor_value { + dtype: 3 + tensor_shape { + dim { + size: 4 + } + } + tensor_content: "\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000" + } + } +} +parameters { + key: "num_steps" + value { + none_value: true + } +} +parameters { + key: "validate_args" + value { + bool_value: false + } +} +parameters { + key: "allow_nan_stats" + value { + bool_value: true + } +} +parameters { + key: "name" + value { + string_value: "Autoregressive" + } +} diff --git a/schema/serde/testdata/batch_broadcast.textproto b/schema/serde/testdata/batch_broadcast.textproto new file mode 100644 index 000000000..b4a698826 --- /dev/null +++ b/schema/serde/testdata/batch_broadcast.textproto @@ -0,0 +1,44 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: TfpDistribution + +distribution_type: "BatchBroadcast" +parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "scale" + value { scalar_value: 5.0 } + } + parameters { + key: "validate_args" + value { bool_value: false } + } + parameters { + key: "allow_nan_stats" + value { bool_value: true } + } + parameters { + key: "name" + value { string_value: "HalfNormal" } + } + } + } +} +parameters { + key: "with_shape" + value { int_value: 3 } +} +parameters { + key: "to_shape" + value { none_value: true } +} +parameters { + key: "validate_args" + value { bool_value: false } +} +parameters { + key: "name" + value { string_value: "batch_broadcast" } +} diff --git a/schema/serde/testdata/broadcast_dist_proto.textproto b/schema/serde/testdata/broadcast_dist_proto.textproto new file mode 100644 index 000000000..2792458da --- /dev/null +++ b/schema/serde/testdata/broadcast_dist_proto.textproto @@ -0,0 +1,2469 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: PriorTfpDistributions + +knot_values { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "knot_values" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "knot_values" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 5 + } + } +} +tau_g_excl_baseline { + distribution_type: "BatchBroadcast" + parameters 
{ + key: "distribution" + value { + distribution_value { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "tau_g_excl_baseline" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "tau_g_excl_baseline" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 9 + } + } +} +beta_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "beta_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +beta_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "beta_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +eta_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "eta_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +eta_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "eta_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + 
parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +gamma_c { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "gamma_c" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "gamma_c" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 3 + } + } +} +xi_c { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "xi_c" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "xi_c" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 3 + } + } +} +alpha_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_m" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "alpha_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +alpha_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_rf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "alpha_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +ec_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + int_value: 10 + } 
+ } + parameters { + key: "loc" + value { + scalar_value: 0.80 + } + } + parameters { + key: "low" + value { + scalar_value: 0.10 + } + } + parameters { + key: "name" + value { + string_value: "ec_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.80 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "ec_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +ec_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "TransformedDistribution" + parameters { + key: "bijector" + value { + bijector_value { + bijector_type: "Shift" + parameters { + key: "name" + value { + string_value: "shift" + } + } + parameters { + key: "shift" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "LogNormal" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "kwargs_split_fn" + value { + function_param { + uses_default: true + } + } + } + parameters { + key: "name" + value { + string_value: "ec_rf" + } + } + parameters { + key: "parameters" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "ec_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +slope_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Deterministic" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "atol" + value { + none_value: true + } + } + parameters { + key: "loc" + value { + scalar_value: 1.00 + } + } + parameters { + key: "name" + value { + string_value: "slope_m" + } + } + parameters { + key: "rtol" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "slope_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +slope_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "slope_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + 
int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "slope_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +sigma { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "sigma" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "sigma" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 10 + } + } +} +roi_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.20 + } + } + parameters { + key: "name" + value { + string_value: "roi_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.90 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "roi_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +roi_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.20 + } + } + parameters { + key: "name" + value { + string_value: "roi_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.90 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "roi_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +beta_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_om" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "beta_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +beta_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + 
key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_orf" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "beta_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +eta_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_om" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "eta_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +eta_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_orf" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "eta_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +gamma_n { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "gamma_n" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "gamma_n" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 5 + } + } +} +xi_n { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "xi_n" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "xi_n" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 5 + 
} + } +} +alpha_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_om" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "alpha_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +alpha_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_orf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "alpha_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +ec_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + int_value: 10 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.80 + } + } + parameters { + key: "low" + value { + scalar_value: 0.10 + } + } + parameters { + key: "name" + value { + string_value: "ec_om" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.80 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "ec_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +ec_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "TransformedDistribution" + parameters { + key: "bijector" + value { + bijector_value { + bijector_type: "Shift" + parameters { + key: "name" + value { + string_value: "shift" + } + } + parameters { + key: "shift" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "LogNormal" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "kwargs_split_fn" + value { + 
function_param { + uses_default: true + } + } + } + parameters { + key: "name" + value { + string_value: "ec_orf" + } + } + parameters { + key: "parameters" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "ec_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +slope_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Deterministic" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "atol" + value { + none_value: true + } + } + parameters { + key: "loc" + value { + scalar_value: 1.00 + } + } + parameters { + key: "name" + value { + string_value: "slope_om" + } + } + parameters { + key: "rtol" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "slope_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +slope_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "slope_orf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "slope_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +mroi_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "mroi_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.50 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "mroi_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +mroi_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "mroi_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.50 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + 
} + parameters { + key: "name" + value { + string_value: "mroi_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +contribution_m { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_m" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "contribution_m" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 6 + } + } +} +contribution_rf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_rf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "contribution_rf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 4 + } + } +} +contribution_om { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_om" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "contribution_om" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 2 + } + } +} +contribution_orf { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + 
int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_orf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "contribution_orf" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 1 + } + } +} +contribution_n { + distribution_type: "BatchBroadcast" + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "low" + value { + scalar_value: -1.00 + } + } + parameters { + key: "name" + value { + string_value: "contribution_n" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "name" + value { + string_value: "contribution_n" + } + } + parameters { + key: "to_shape" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + parameters { + key: "with_shape" + value { + int_value: 5 + } + } +} diff --git a/schema/serde/testdata/list_halfnormal.textproto b/schema/serde/testdata/list_halfnormal.textproto new file mode 100644 index 000000000..c4c9ea627 --- /dev/null +++ b/schema/serde/testdata/list_halfnormal.textproto @@ -0,0 +1,31 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: TfpDistribution + +distribution_type: "HalfNormal" +parameters { + key: "scale" + value { + list_value { + values { scalar_value: 1.0 } + values { scalar_value: 1.1 } + values { scalar_value: 1.2 } + values { scalar_value: 1.3 } + values { scalar_value: 1.4 } + values { scalar_value: 1.5 } + } + } +} +parameters { + key: "name" + value { + string_value: "list_halfnormal" + } +} +parameters { + key: "validate_args" + value { bool_value: false } +} +parameters { + key: "allow_nan_stats" + value { bool_value: true } +} diff --git a/schema/serde/testdata/nested_priors_proto.textproto b/schema/serde/testdata/nested_priors_proto.textproto new file mode 100644 index 000000000..fb5e918b7 --- /dev/null +++ b/schema/serde/testdata/nested_priors_proto.textproto @@ -0,0 +1,46 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: PriorTfpDistributions + +roi_m { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + list_value { + values { + scalar_value: 0.12 + } + values { + scalar_value: 0.23 + } + values { + scalar_value: 0.34 + } + } + } + } + parameters { + key: "name" + value { + string_value: "roi_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.90 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} diff --git a/schema/serde/testdata/non_default_priors.textproto b/schema/serde/testdata/non_default_priors.textproto new file mode 100644 index 000000000..a3e6b3463 --- /dev/null +++ b/schema/serde/testdata/non_default_priors.textproto @@ -0,0 +1,75 @@ +# proto-file: 
third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: PriorDistributions + +knot_values { + name: "knot_values" + normal { + locs: 2 + locs: 3 + locs: 4 + scales: 7 + scales: 8 + scales: 9 + } +} +tau_g_excl_baseline { + name: "tau_g_excl_baseline" + normal { + locs: 3 + scales: 10 + } +} +beta_m { + name: "beta_m" + half_normal { + scales: 4 + } +} +beta_rf { + name: "beta_rf" + half_normal { + scales: 2 + scales: 3 + } +} +alpha_rf { + name: "alpha_rf" + uniform { + lows: 0.1 + lows: 0.2 + highs: 0.8 + highs: 0.9 + } +} +alpha_om { + name: "alpha_om" + uniform { + lows: 0.1 + highs: 0.9 + } +} +alpha_orf { + name: "alpha_orf" + uniform { + lows: 0.1 + highs: 0.9 + } +} +ec_rf { + name: "ec_rf" + transformed { + distribution { + name: "LogNormal" + log_normal { + locs: 0.70 + scales: 0.40 + } + } + bijector { + name: "scale" + scale { + scales: 11 + } + } + } +} diff --git a/schema/serde/testdata/prior_dist_proto.textproto b/schema/serde/testdata/prior_dist_proto.textproto new file mode 100644 index 000000000..8fa62810c --- /dev/null +++ b/schema/serde/testdata/prior_dist_proto.textproto @@ -0,0 +1,1317 @@ +# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto +# proto-message: PriorTfpDistributions + +knot_values { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "knot_values" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +tau_g_excl_baseline { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "tau_g_excl_baseline" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +beta_m { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +beta_rf { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "beta_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +eta_m { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +eta_rf { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "eta_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 1.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +gamma_c { + distribution_type: "Normal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } 
+ } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "gamma_c" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +xi_c { + distribution_type: "HalfNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "name" + value { + string_value: "xi_c" + } + } + parameters { + key: "scale" + value { + scalar_value: 5.00 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +alpha_m { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_m" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +alpha_rf { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_rf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +ec_m { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + int_value: 10 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.80 + } + } + parameters { + key: "low" + value { + scalar_value: 0.10 + } + } + parameters { + key: "name" + value { + string_value: "ec_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.80 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +ec_rf { + distribution_type: "TransformedDistribution" + parameters { + key: "bijector" + value { + bijector_value { + bijector_type: "Shift" + parameters { + key: "name" + value { + string_value: "shift" + } + } + parameters { + key: "shift" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "LogNormal" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "kwargs_split_fn" + value { + function_param { + uses_default: true + } + } + } + parameters { + key: "name" + value { + string_value: "ec_rf" + } + } + parameters { + key: "parameters" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +slope_m { + distribution_type: "Deterministic" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "atol" + value { + none_value: true + } + } + parameters { + key: "loc" + value { + scalar_value: 1.00 + } + } + parameters { + key: "name" + value { + string_value: "slope_m" + } + } + parameters { + key: "rtol" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + 
+slope_rf {
+  distribution_type: "LogNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "loc"
+    value {
+      scalar_value: 0.70
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "slope_rf"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 0.40
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+sigma {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "sigma"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 5.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+roi_m {
+  distribution_type: "LogNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "loc"
+    value {
+      scalar_value: 0.20
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "roi_m"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 0.90
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+roi_rf {
+  distribution_type: "LogNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "loc"
+    value {
+      scalar_value: 0.20
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "roi_rf"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 0.90
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+beta_om {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "beta_om"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 5.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+beta_orf {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "beta_orf"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 5.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+eta_om {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "eta_om"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 1.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+eta_orf {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "eta_orf"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 1.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+gamma_n {
+  distribution_type: "Normal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "loc"
+    value {
+      scalar_value: 0.00
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "gamma_n"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 5.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
+xi_n {
+  distribution_type: "HalfNormal"
+  parameters {
+    key: "allow_nan_stats"
+    value {
+      int_value: 1
+    }
+  }
+  parameters {
+    key: "name"
+    value {
+      string_value: "xi_n"
+    }
+  }
+  parameters {
+    key: "scale"
+    value {
+      scalar_value: 5.00
+    }
+  }
+  parameters {
+    key: "validate_args"
+    value {
+      int_value: 0
+    }
+  }
+}
"validate_args" + value { + int_value: 0 + } + } +} +alpha_om { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_om" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +alpha_orf { + distribution_type: "Uniform" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "low" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "alpha_orf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +ec_om { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + int_value: 10 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.80 + } + } + parameters { + key: "low" + value { + scalar_value: 0.10 + } + } + parameters { + key: "name" + value { + string_value: "ec_om" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.80 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +ec_orf { + distribution_type: "TransformedDistribution" + parameters { + key: "bijector" + value { + bijector_value { + bijector_type: "Shift" + parameters { + key: "name" + value { + string_value: "shift" + } + } + parameters { + key: "shift" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "distribution" + value { + distribution_value { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "LogNormal" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } + } + } + } + parameters { + key: "kwargs_split_fn" + value { + function_param { + uses_default: true + } + } + } + parameters { + key: "name" + value { + string_value: "ec_orf" + } + } + parameters { + key: "parameters" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +slope_om { + distribution_type: "Deterministic" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "atol" + value { + none_value: true + } + } + parameters { + key: "loc" + value { + scalar_value: 1.00 + } + } + parameters { + key: "name" + value { + string_value: "slope_om" + } + } + parameters { + key: "rtol" + value { + none_value: true + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +slope_orf { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.70 + } + } + parameters { + key: "name" + value { + string_value: "slope_orf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.40 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +mroi_m { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: 
"loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "mroi_m" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.50 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +mroi_rf { + distribution_type: "LogNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "name" + value { + string_value: "mroi_rf" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.50 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +contribution_m { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_m" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +contribution_rf { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_rf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +contribution_om { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_om" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +contribution_orf { + distribution_type: "Beta" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "concentration0" + value { + scalar_value: 99.00 + } + } + parameters { + key: "concentration1" + value { + scalar_value: 1.00 + } + } + parameters { + key: "force_probs_to_zero_outside_support" + value { + int_value: 0 + } + } + parameters { + key: "name" + value { + string_value: "contribution_orf" + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} +contribution_n { + distribution_type: "TruncatedNormal" + parameters { + key: "allow_nan_stats" + value { + int_value: 1 + } + } + parameters { + key: "high" + value { + scalar_value: 1.00 + } + } + parameters { + key: "loc" + value { + scalar_value: 0.00 + } + } + parameters { + key: "low" + value { + scalar_value: -1.00 + } + } + parameters { + key: "name" + value { + string_value: "contribution_n" + } + } + parameters { + key: "scale" + value { + scalar_value: 0.10 + } + } + parameters { + key: "validate_args" + value { + int_value: 0 + } + } +} diff --git a/schema/serde/testdata/scalar_halfnormal.textproto b/schema/serde/testdata/scalar_halfnormal.textproto new file mode 100644 index 000000000..a66361095 --- /dev/null +++ b/schema/serde/testdata/scalar_halfnormal.textproto @@ 
diff --git a/schema/serde/testdata/scalar_halfnormal.textproto b/schema/serde/testdata/scalar_halfnormal.textproto
new file mode 100644
index 000000000..a66361095
--- /dev/null
+++ b/schema/serde/testdata/scalar_halfnormal.textproto
@@ -0,0 +1,20 @@
+# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto
+# proto-message: TfpDistribution
+
+distribution_type: "HalfNormal"
+parameters {
+  key: "scale"
+  value { scalar_value: 1.0 }
+}
+parameters {
+  key: "name"
+  value { string_value: "scalar_halfnormal" }
+}
+parameters {
+  key: "validate_args"
+  value: { bool_value: false }
+}
+parameters {
+  key: "allow_nan_stats"
+  value: { bool_value: true }
+}
diff --git a/schema/serde/testdata/transformed_distribution_reciprocal.textproto b/schema/serde/testdata/transformed_distribution_reciprocal.textproto
new file mode 100644
index 000000000..7f3c34445
--- /dev/null
+++ b/schema/serde/testdata/transformed_distribution_reciprocal.textproto
@@ -0,0 +1,70 @@
+# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto
+# proto-message: TfpDistribution
+
+distribution_type: "TransformedDistribution"
+parameters {
+  key: "distribution"
+  value {
+    distribution_value {
+      distribution_type: "LogNormal"
+      parameters {
+        key: "loc"
+        value {
+          scalar_value: 0.7
+        }
+      }
+      parameters {
+        key: "scale"
+        value {
+          scalar_value: 0.4
+        }
+      }
+    }
+  }
+}
+parameters {
+  key: "bijector"
+  value {
+    bijector_value {
+      bijector_type: "Reciprocal"
+      parameters: {
+        key: "name"
+        value: {
+          string_value: "reciprocal"
+        }
+      }
+      parameters: {
+        key: "validate_args"
+        value: {
+          bool_value: false
+        }
+      }
+    }
+  }
+}
+parameters {
+  key: "kwargs_split_fn"
+  value {
+    function_param {
+      uses_default: true
+    }
+  }
+}
+parameters {
+  key: "validate_args"
+  value {
+    bool_value: false
+  }
+}
+parameters {
+  key: "parameters"
+  value {
+    none_value: true
+  }
+}
+parameters {
+  key: "name"
+  value {
+    string_value: "transformed_distribution_reciprocal"
+  }
+}
diff --git a/schema/serde/testdata/transformed_distribution_shift.textproto b/schema/serde/testdata/transformed_distribution_shift.textproto
new file mode 100644
index 000000000..9cd420775
--- /dev/null
+++ b/schema/serde/testdata/transformed_distribution_shift.textproto
@@ -0,0 +1,76 @@
+# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto
+# proto-message: TfpDistribution
+
+distribution_type: "TransformedDistribution"
+parameters {
+  key: "distribution"
+  value {
+    distribution_value {
+      distribution_type: "LogNormal"
+      parameters {
+        key: "loc"
+        value {
+          scalar_value: 0.7
+        }
+      }
+      parameters {
+        key: "scale"
+        value {
+          scalar_value: 0.4
+        }
+      }
+    }
+  }
+}
+parameters {
+  key: "bijector"
+  value {
+    bijector_value {
+      bijector_type: "Shift"
+      parameters: {
+        key: "name"
+        value: {
+          string_value: "shift"
+        }
+      }
+      parameters: {
+        key: "validate_args"
+        value: {
+          bool_value: false
+        }
+      }
+      parameters: {
+        key: "shift"
+        value: {
+          scalar_value: 0.1
+        }
+      }
+    }
+  }
+}
+parameters {
+  key: "kwargs_split_fn"
+  value {
+    function_param {
+      uses_default: true
+    }
+  }
+}
+parameters {
+  key: "validate_args"
+  value {
+    bool_value: false
+  }
+}
+parameters {
+  key: "parameters"
+  value {
+    none_value: true
+  }
+}
+parameters {
+  key: "name"
+  value {
+    string_value: "transformed_distribution_shift"
+  }
+}
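+# Editorial sketch (hypothetical, not asserted verbatim by the serde tests):
+# this message is expected to correspond to a TFP object along the lines of
+#   tfp.distributions.TransformedDistribution(
+#       distribution=tfp.distributions.LogNormal(loc=0.7, scale=0.4),
+#       bijector=tfp.bijectors.Shift(shift=0.1),
+#   )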
diff --git a/schema/serde/testdata/truncated_normal.textproto b/schema/serde/testdata/truncated_normal.textproto
new file mode 100644
index 000000000..92707c45c
--- /dev/null
+++ b/schema/serde/testdata/truncated_normal.textproto
@@ -0,0 +1,32 @@
+# proto-file: third_party/py/meridian/proto/mmm/v1/model/meridian/meridian_model.proto
+# proto-message: TfpDistribution
+
+distribution_type: "TruncatedNormal"
+parameters {
+  key: "loc"
+  value { scalar_value: 0.7 }
+}
+parameters {
+  key: "scale"
+  value { scalar_value: 0.4 }
+}
+parameters {
+  key: "low"
+  value { scalar_value: 0.1 }
+}
+parameters {
+  key: "high"
+  value { scalar_value: 10 }
+}
+parameters {
+  key: "name"
+  value: { string_value: "truncated_normal" }
+}
+parameters {
+  key: "validate_args"
+  value: { bool_value: false }
+}
+parameters {
+  key: "allow_nan_stats"
+  value: { bool_value: true }
+}
diff --git a/schema/test_data.py b/schema/test_data.py
new file mode 100644
index 000000000..bf27124ca
--- /dev/null
+++ b/schema/test_data.py
@@ -0,0 +1,380 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test data for MMM proto generator."""
+
+from collections.abc import Sequence
+import datetime
+
+from mmm.v1 import mmm_pb2 as mmm_pb
+from mmm.v1.common import date_interval_pb2 as date_interval_pb
+from mmm.v1.fit import model_fit_pb2 as fit_pb
+from mmm.v1.marketing.analysis import marketing_analysis_pb2
+from mmm.v1.marketing.optimization import budget_optimization_pb2 as budget_pb
+from mmm.v1.marketing.optimization import reach_frequency_optimization_pb2 as rf_pb
+from schema.processors import budget_optimization_processor
+from schema.processors import marketing_processor
+from schema.processors import model_fit_processor
+from schema.processors import model_processor
+from schema.processors import reach_frequency_optimization_processor as rf_opt_processor
+
+from google.type import date_pb2
+
+# Weekly dates from 2022-11-21 to 2024-01-01.
+ALL_TIMES_IN_MERIDIAN = (
+    '2022-11-21',
+    '2022-11-28',
+    '2022-12-05',
+    '2022-12-12',
+    '2022-12-19',
+    '2022-12-26',
+    '2023-01-02',
+    '2023-01-09',
+    '2023-01-16',
+    '2023-01-23',
+    '2023-01-30',
+    '2023-02-06',
+    '2023-02-13',
+    '2023-02-20',
+    '2023-02-27',
+    '2023-03-06',
+    '2023-03-13',
+    '2023-03-20',
+    '2023-03-27',
+    '2023-04-03',
+    '2023-04-10',
+    '2023-04-17',
+    '2023-04-24',
+    '2023-05-01',
+    '2023-05-08',
+    '2023-05-15',
+    '2023-05-22',
+    '2023-05-29',
+    '2023-06-05',
+    '2023-06-12',
+    '2023-06-19',
+    '2023-06-26',
+    '2023-07-03',
+    '2023-07-10',
+    '2023-07-17',
+    '2023-07-24',
+    '2023-07-31',
+    '2023-08-07',
+    '2023-08-14',
+    '2023-08-21',
+    '2023-08-28',
+    '2023-09-04',
+    '2023-09-11',
+    '2023-09-18',
+    '2023-09-25',
+    '2023-10-02',
+    '2023-10-09',
+    '2023-10-16',
+    '2023-10-23',
+    '2023-10-30',
+    '2023-11-06',
+    '2023-11-13',
+    '2023-11-20',
+    '2023-11-27',
+    '2023-12-04',
+    '2023-12-11',
+    '2023-12-18',
+    '2023-12-25',
+    '2024-01-01',
+)
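+
+# Editorial note: the DatedSpecs below mirror what the bucketing helpers in
+# schema/utils/date_range_bucketing.py yield for ALL_TIMES_IN_MERIDIAN. A
+# hypothetical reconstruction of the monthly entries (assumes importing
+# schema.utils.date_range_bucketing, which this module does not do):
+#   dates = [datetime.date.fromisoformat(t) for t in ALL_TIMES_IN_MERIDIAN]
+#   monthly = date_range_bucketing.MonthlyDateRangeGenerator(dates)
+#   specs = [
+#       model_processor.DatedSpec(
+#           start_date=s, end_date=e, date_interval_tag=s.strftime('Y%Y %b')
+#       )
+#       for s, e in monthly.generate_date_intervals()
+#   ]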
+ALL_TIME_BUCKET_DATED_SPECS = (
+    # All
+    model_processor.DatedSpec(
+        start_date=datetime.date(2022, 11, 21),
+        end_date=datetime.date(2024, 1, 8),
+        date_interval_tag='ALL',
+    ),
+    # Monthly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2022, 12, 5),
+        end_date=datetime.date(2023, 1, 2),
+        date_interval_tag='Y2022 Dec',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2023, 2, 6),
+        date_interval_tag='Y2023 Jan',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 2, 6),
+        end_date=datetime.date(2023, 3, 6),
+        date_interval_tag='Y2023 Feb',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 3, 6),
+        end_date=datetime.date(2023, 4, 3),
+        date_interval_tag='Y2023 Mar',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 4, 3),
+        end_date=datetime.date(2023, 5, 1),
+        date_interval_tag='Y2023 Apr',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 5, 1),
+        end_date=datetime.date(2023, 6, 5),
+        date_interval_tag='Y2023 May',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 6, 5),
+        end_date=datetime.date(2023, 7, 3),
+        date_interval_tag='Y2023 Jun',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 7, 3),
+        end_date=datetime.date(2023, 8, 7),
+        date_interval_tag='Y2023 Jul',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 8, 7),
+        end_date=datetime.date(2023, 9, 4),
+        date_interval_tag='Y2023 Aug',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 9, 4),
+        end_date=datetime.date(2023, 10, 2),
+        date_interval_tag='Y2023 Sep',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 10, 2),
+        end_date=datetime.date(2023, 11, 6),
+        date_interval_tag='Y2023 Oct',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 11, 6),
+        end_date=datetime.date(2023, 12, 4),
+        date_interval_tag='Y2023 Nov',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 12, 4),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023 Dec',
+    ),
+    # Quarterly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2023, 4, 3),
+        date_interval_tag='Y2023 Q1',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 4, 3),
+        end_date=datetime.date(2023, 7, 3),
+        date_interval_tag='Y2023 Q2',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 7, 3),
+        end_date=datetime.date(2023, 10, 2),
+        date_interval_tag='Y2023 Q3',
+    ),
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 10, 2),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023 Q4',
+    ),
+    # Yearly buckets
+    model_processor.DatedSpec(
+        start_date=datetime.date(2023, 1, 2),
+        end_date=datetime.date(2024, 1, 1),
+        date_interval_tag='Y2023',
+    ),
+)
+
+
+def _dated_spec_to_date_interval(
+    spec: model_processor.DatedSpec,
+) -> date_interval_pb.DateInterval:
+  if spec.start_date is None or spec.end_date is None:
+    raise ValueError('Start date or end date is None.')
+
+  return date_interval_pb.DateInterval(
+      start_date=date_pb2.Date(
+          year=spec.start_date.year,
+          month=spec.start_date.month,
+          day=spec.start_date.day,
+      ),
+      end_date=date_pb2.Date(
+          year=spec.end_date.year,
+          month=spec.end_date.month,
+          day=spec.end_date.day,
+      ),
+      tag=spec.date_interval_tag,
+  )
+
+
+class FakeModelFitProcessor(
+    model_processor.ModelProcessor[
+        model_fit_processor.ModelFitSpec, fit_pb.ModelFit
+    ]
+):
+  """Fake ModelFitProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return model_fit_processor.ModelFitSpec
+
+  @classmethod
+  def output_type(cls):
+    return fit_pb.ModelFit
+
+  def execute(
+      self, specs: Sequence[model_fit_processor.ModelFitSpec]
+  ) -> fit_pb.ModelFit:
+    return fit_pb.ModelFit()
+
+  def _set_output(self, output: mmm_pb.Mmm, result: fit_pb.ModelFit):
+    output.model_fit.CopyFrom(result)
+
+
+class FakeBudgetOptimizationProcessor(
+    model_processor.ModelProcessor[
+        budget_optimization_processor.BudgetOptimizationSpec,
+        budget_pb.BudgetOptimization,
+    ]
+):
+  """Fake BudgetOptimizationProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return budget_optimization_processor.BudgetOptimizationSpec
+
+  @classmethod
+  def output_type(cls):
+    return budget_pb.BudgetOptimization
+
+  def execute(
+      self,
+      specs: Sequence[budget_optimization_processor.BudgetOptimizationSpec],
+  ) -> budget_pb.BudgetOptimization:
+    results = []
+    for spec in specs:
+      result = budget_pb.BudgetOptimizationResult(
+          name=spec.optimization_name,
+          spec=budget_pb.BudgetOptimizationSpec(
+              date_interval=_dated_spec_to_date_interval(spec)
+          ),
+          incremental_outcome_grid=budget_pb.IncrementalOutcomeGrid(
+              name=spec.grid_name
+          ),
+      )
+      if spec.group_id:
+        result.group_id = spec.group_id
+      results.append(result)
+
+    return budget_pb.BudgetOptimization(results=results)
+
+  def _set_output(
+      self, output: mmm_pb.Mmm, result: budget_pb.BudgetOptimization
+  ):
+    output.marketing_optimization.budget_optimization.CopyFrom(result)
+
+
+class FakeReachFrequencyOptimizationProcessor(
+    model_processor.ModelProcessor[
+        rf_opt_processor.ReachFrequencyOptimizationSpec,
+        rf_pb.ReachFrequencyOptimization,
+    ]
+):
+  """Fake ReachFrequencyOptimizationProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return rf_opt_processor.ReachFrequencyOptimizationSpec
+
+  @classmethod
+  def output_type(cls):
+    return rf_pb.ReachFrequencyOptimization
+
+  def execute(
+      self,
+      specs: Sequence[rf_opt_processor.ReachFrequencyOptimizationSpec],
+  ) -> rf_pb.ReachFrequencyOptimization:
+    results = []
+    for spec in specs:
+      result = rf_pb.ReachFrequencyOptimizationResult(
+          name=spec.optimization_name,
+          spec=rf_pb.ReachFrequencyOptimizationSpec(
+              date_interval=_dated_spec_to_date_interval(spec)
+          ),
+          frequency_outcome_grid=rf_pb.FrequencyOutcomeGrid(
+              name=spec.grid_name
+          ),
+      )
+      if spec.group_id:
+        result.group_id = spec.group_id
+      results.append(result)
+
+    return rf_pb.ReachFrequencyOptimization(results=results)
+
+  def _set_output(
+      self,
+      output: mmm_pb.Mmm,
+      result: rf_pb.ReachFrequencyOptimization,
+  ):
+    output.marketing_optimization.reach_frequency_optimization.CopyFrom(result)
+
+
+class FakeMarketingProcessor(
+    model_processor.ModelProcessor[
+        marketing_processor.MarketingAnalysisSpec,
+        marketing_analysis_pb2.MarketingAnalysisList,
+    ]
+):
+  """Fake MarketingProcessor for testing."""
+
+  def __init__(self, trained_model: model_processor.TrainedModel):
+    self._trained_model = trained_model
+
+  @classmethod
+  def spec_type(cls):
+    return marketing_processor.MarketingAnalysisSpec
+
+  @classmethod
+  def output_type(cls):
+    return marketing_analysis_pb2.MarketingAnalysisList
+
+  def execute(
+      self, specs: Sequence[marketing_processor.MarketingAnalysisSpec]
+  ) -> marketing_analysis_pb2.MarketingAnalysisList:
+    marketing_analyses = []
+    for spec in specs:
+      marketing_analysis = marketing_analysis_pb2.MarketingAnalysis(
+          date_interval=_dated_spec_to_date_interval(spec)
+      )
+      marketing_analyses.append(marketing_analysis)
+
+    return marketing_analysis_pb2.MarketingAnalysisList(
+        marketing_analyses=marketing_analyses
+    )
+
+  def _set_output(
+      self,
+      output: mmm_pb.Mmm,
+      result: marketing_analysis_pb2.MarketingAnalysisList,
+  ):
+    output.marketing_analysis_list.CopyFrom(result)
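+
+
+# Editorial sketch (hypothetical usage, not part of the module): these fakes
+# let pipeline tests run without a real trained model, e.g.
+#   processor = FakeBudgetOptimizationProcessor(trained_model)
+#   budget_optimization = processor.execute(specs)  # one result per spec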
+ +"""Helper classes for generating date intervals for various time buckets.""" + +import abc +from collections.abc import Iterator, Sequence +import datetime +from typing import TypeAlias + + +__all__ = [ + "DateRangeBucketer", + "MonthlyDateRangeGenerator", + "QuarterlyDateRangeGenerator", + "YearlyDateRangeGenerator", +] + + +DateInterval: TypeAlias = tuple[datetime.date, datetime.date] + + +class DateRangeBucketer(abc.ABC): + """Generates `DateInterval` protos over a range of dates.""" + + def __init__( + self, + input_dates: Sequence[datetime.date], + ): + """Initializes the DateRangeBucketer with a sequence of dates. + + Args: + input_dates: A sequence of `datetime.date` objects representing the range + of dates to generate intervals for. + """ + if not all( + input_dates[i] < input_dates[i + 1] for i in range(len(input_dates) - 1) + ): + raise ValueError("`input_dates` must be strictly ascending dates.") + + self._input_dates = input_dates + + @abc.abstractmethod + def generate_date_intervals(self) -> Iterator[DateInterval]: + """Generates `DateInterval` protos for the class's input dates. + + Each interval represents a month, quarter, or year, depending on the + instance of this class. An interval is excluded if the start date is not the + first available date (in `self._input_dates`) for the time bucket. The last + interval in `self._input_dates` is excluded in all cases. + + Returns: + An iterator over generated `TimeInterval`s. + """ + raise NotImplementedError() + + +class MonthlyDateRangeGenerator(DateRangeBucketer): + """Generates monthly date intervals.""" + + def generate_date_intervals(self) -> Iterator[DateInterval]: + start_date = self._input_dates[0] + + for date in self._input_dates: + if date.month != start_date.month: + if start_date.day <= 7: + yield (start_date, date) + + start_date = date + + +class QuarterlyDateRangeGenerator(DateRangeBucketer): + """Generates quarterly date intervals.""" + + def generate_date_intervals(self) -> Iterator[DateInterval]: + start_date = self._input_dates[0] + for date in self._input_dates: + start_date_quarter_number = (start_date.month - 1) // 3 + 1 + current_date_quarter_number = (date.month - 1) // 3 + 1 + + if start_date_quarter_number != current_date_quarter_number: + # The interval is only included if the start date is the first date of + # the quarter that's present in `self._input_dates`. We can detect this + # date by checking whether it's in the first month of the quarter and + # falls in the first seven days of the month. + if ( + start_date.day <= 7 + and start_date.month == ((start_date_quarter_number - 1) * 3) + 1 + ): + yield (start_date, date) + + start_date = date + + +class YearlyDateRangeGenerator(DateRangeBucketer): + """Generates yearly date intervals.""" + + def generate_date_intervals(self) -> Iterator[DateInterval]: + start_date = self._input_dates[0] + + for date in self._input_dates: + if date.year != start_date.year: + if start_date.day <= 7 and start_date.month == 1: + yield (start_date, date) + + start_date = date diff --git a/schema/utils/date_range_bucketing_test.py b/schema/utils/date_range_bucketing_test.py new file mode 100644 index 000000000..f0d06f0df --- /dev/null +++ b/schema/utils/date_range_bucketing_test.py @@ -0,0 +1,261 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/schema/utils/date_range_bucketing_test.py b/schema/utils/date_range_bucketing_test.py
new file mode 100644
index 000000000..f0d06f0df
--- /dev/null
+++ b/schema/utils/date_range_bucketing_test.py
@@ -0,0 +1,261 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+from absl.testing import absltest
+from schema.utils import date_range_bucketing
+
+
+class MonthlyDateRangeGeneratorTest(absltest.TestCase):
+
+  def test_generate_date_intervals_skips_first_interval_if_not_start_of_month(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 1, 15),
+        datetime.date(2023, 1, 22),
+        datetime.date(2023, 1, 29),
+        datetime.date(2023, 2, 6),
+        datetime.date(2023, 2, 13),
+        datetime.date(2023, 2, 20),
+        datetime.date(2023, 2, 27),
+        datetime.date(2023, 3, 6),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 2, 6), datetime.date(2023, 3, 6)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.MonthlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+  def test_generate_date_intervals_skips_last_interval(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 1, 1),
+        datetime.date(2023, 1, 8),
+        datetime.date(2023, 1, 15),
+        datetime.date(2023, 1, 22),
+        datetime.date(2023, 1, 29),
+        datetime.date(2023, 2, 6),
+        datetime.date(2023, 2, 13),
+        datetime.date(2023, 2, 20),
+        datetime.date(2023, 2, 27),
+        datetime.date(2023, 3, 6),
+        datetime.date(2023, 3, 13),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 1, 1), datetime.date(2023, 2, 6)),
+        (datetime.date(2023, 2, 6), datetime.date(2023, 3, 6)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.MonthlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+
+class QuarterlyDateRangeGeneratorTest(absltest.TestCase):
+
+  def test_generate_date_intervals_skips_first_interval_if_not_start_of_qtr(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 2, 1),
+        datetime.date(2023, 3, 1),
+        datetime.date(2023, 4, 1),
+        datetime.date(2023, 5, 1),
+        datetime.date(2023, 6, 1),
+        datetime.date(2023, 7, 1),
+        datetime.date(2023, 8, 1),
+        datetime.date(2023, 9, 1),
+        datetime.date(2023, 10, 1),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 4, 1), datetime.date(2023, 7, 1)),
+        (datetime.date(2023, 7, 1), datetime.date(2023, 10, 1)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.QuarterlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+  def test_generate_date_intervals_skips_last_interval(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 1, 1),
+        datetime.date(2023, 2, 1),
+        datetime.date(2023, 3, 1),
+        datetime.date(2023, 4, 1),
+        datetime.date(2023, 5, 1),
+        datetime.date(2023, 6, 1),
+        datetime.date(2023, 7, 1),
+        datetime.date(2023, 8, 1),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 1, 1), datetime.date(2023, 4, 1)),
+        (datetime.date(2023, 4, 1), datetime.date(2023, 7, 1)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.QuarterlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+
+class YearlyDateRangeGeneratorTest(absltest.TestCase):
+
+  def test_generate_date_intervals_returns_empty_when_full_year_is_not_covered(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2022, 11, 1),
+        datetime.date(2022, 12, 1),
+        datetime.date(2023, 1, 1),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.YearlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertEmpty(date_intervals)
+
+  def test_generate_date_intervals_skips_first_interval_if_not_first_month(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2022, 12, 1),
+        datetime.date(2023, 1, 1),
+        datetime.date(2023, 2, 1),
+        datetime.date(2023, 3, 1),
+        datetime.date(2023, 4, 1),
+        datetime.date(2023, 5, 1),
+        datetime.date(2023, 6, 1),
+        datetime.date(2023, 7, 1),
+        datetime.date(2023, 8, 1),
+        datetime.date(2023, 9, 1),
+        datetime.date(2023, 10, 1),
+        datetime.date(2023, 11, 1),
+        datetime.date(2023, 12, 1),
+        datetime.date(2024, 1, 1),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 1, 1), datetime.date(2024, 1, 1)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.YearlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+  def test_generate_date_intervals_skips_last_interval(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 1, 1),
+        datetime.date(2023, 2, 1),
+        datetime.date(2023, 3, 1),
+        datetime.date(2023, 4, 1),
+        datetime.date(2023, 5, 1),
+        datetime.date(2023, 6, 1),
+        datetime.date(2023, 7, 1),
+        datetime.date(2023, 8, 1),
+        datetime.date(2023, 9, 1),
+        datetime.date(2023, 10, 1),
+        datetime.date(2023, 11, 1),
+        datetime.date(2023, 12, 1),
+        datetime.date(2024, 1, 1),
+        datetime.date(2024, 2, 1),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 1, 1), datetime.date(2024, 1, 1)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.YearlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+  def test_generate_date_intervals_produces_two_intervals_for_two_full_years(
+      self,
+  ):
+    input_dates = [
+        datetime.date(2023, 1, 1),
+        datetime.date(2023, 2, 1),
+        datetime.date(2023, 3, 1),
+        datetime.date(2023, 4, 1),
+        datetime.date(2023, 5, 1),
+        datetime.date(2023, 6, 1),
+        datetime.date(2023, 7, 1),
+        datetime.date(2023, 8, 1),
+        datetime.date(2023, 9, 1),
+        datetime.date(2023, 10, 1),
+        datetime.date(2023, 11, 1),
+        datetime.date(2023, 12, 1),
+        datetime.date(2024, 1, 1),
+        datetime.date(2024, 2, 1),
+        datetime.date(2024, 3, 1),
+        datetime.date(2024, 4, 1),
+        datetime.date(2024, 5, 1),
+        datetime.date(2024, 6, 1),
+        datetime.date(2024, 7, 1),
+        datetime.date(2024, 8, 1),
+        datetime.date(2024, 9, 1),
+        datetime.date(2024, 10, 1),
+        datetime.date(2024, 11, 1),
+        datetime.date(2024, 12, 1),
+        datetime.date(2025, 1, 1),
+        datetime.date(2025, 2, 1),
+    ]
+    expected_date_intervals = [
+        (datetime.date(2023, 1, 1), datetime.date(2024, 1, 1)),
+        (datetime.date(2024, 1, 1), datetime.date(2025, 1, 1)),
+    ]
+
+    date_intervals = list(
+        date_range_bucketing.YearlyDateRangeGenerator(
+            input_dates
+        ).generate_date_intervals()
+    )
+
+    self.assertSequenceEqual(date_intervals, expected_date_intervals)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/schema/utils/time_record.py b/schema/utils/time_record.py
new file mode 100644
index 000000000..69e02cdf0
--- /dev/null
+++ b/schema/utils/time_record.py
@@ -0,0 +1,156 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for time-related operations."""
+
+from collections.abc import Mapping, MutableMapping, Sequence
+import datetime
+
+from meridian import constants
+from mmm.v1.common import date_interval_pb2
+import pandas as pd
+
+from google.type import date_pb2
+
+
+__all__ = [
+    "convert_times_to_date_intervals",
+    "create_date_interval_pb",
+    "dates_from_date_interval_proto",
+]
+
+
+def convert_times_to_date_intervals(
+    times: Sequence[str] | Sequence[datetime.date] | pd.DatetimeIndex,
+) -> Mapping[str, date_interval_pb2.DateInterval]:
+  """Maps each time in `times` to a corresponding `DateInterval` proto.
+
+  Args:
+    times: Sequence of date strings in YYYY-MM-DD format, sequence of
+      `datetime.date` objects, or a `pd.DatetimeIndex`.
+
+  Returns:
+    Mapping from each time in `times` (in YYYY-MM-DD string form) to the
+    corresponding date interval.
+
+  Raises:
+    ValueError: If `times` has fewer than 2 elements, if a string entry is not
+      in YYYY-MM-DD format, or if the interval length between consecutive
+      times is not consistent.
+  """
+  if len(times) < 2:
+    raise ValueError("There must be at least 2 time points.")
+
+  if isinstance(times, pd.DatetimeIndex):
+    datetimes = [
+        datetime.datetime(year=time.year, month=time.month, day=time.day)
+        for time in times
+    ]
+  else:
+    datetimes = [
+        datetime.datetime.strptime(time, constants.DATE_FORMAT)
+        if isinstance(time, str)
+        else time
+        for time in times
+    ]
+
+  interval_length = _compute_interval_length(
+      start_date=datetimes[0],
+      end_date=datetimes[1],
+  )
+  time_to_date_interval: MutableMapping[str, date_interval_pb2.DateInterval] = (
+      {}
+  )
+
+  for i, start_date in enumerate(datetimes):
+    if i == len(datetimes) - 1:
+      end_date = start_date + datetime.timedelta(days=interval_length)
+    else:
+      end_date = datetimes[i + 1]
+      current_interval_length = _compute_interval_length(
+          start_date=start_date,
+          end_date=end_date,
+      )
+
+      if current_interval_length != interval_length:
+        raise ValueError(
+            "Interval length between selected times must be consistent."
+        )
+
+    date_interval = create_date_interval_pb(start_date, end_date)
+    time_to_date_interval[start_date.strftime(constants.DATE_FORMAT)] = (
+        date_interval
+    )
+
+  return time_to_date_interval
+
+
+def create_date_interval_pb(
+    start_date: datetime.date, end_date: datetime.date, tag: str = ""
+) -> date_interval_pb2.DateInterval:
+  """Creates a `DateInterval` proto for the given start and end dates.
+
+  Args:
+    start_date: A date object representing the start date.
+    end_date: A date object representing the end date.
+    tag: An optional tag to identify the date interval.
+
+  Returns:
+    A `DateInterval` proto wrapping the start/end dates.
+  """
+ """ + start_date_proto = date_pb2.Date( + year=start_date.year, + month=start_date.month, + day=start_date.day, + ) + end_date_proto = date_pb2.Date( + year=end_date.year, + month=end_date.month, + day=end_date.day, + ) + return date_interval_pb2.DateInterval( + start_date=start_date_proto, + end_date=end_date_proto, + tag=tag, + ) + + +def dates_from_date_interval_proto( + date_interval: date_interval_pb2.DateInterval, +) -> tuple[datetime.date, datetime.date]: + """Returns a tuple of `[start, end)` date range from a `DateInterval` proto.""" + start_date = datetime.date( + date_interval.start_date.year, + date_interval.start_date.month, + date_interval.start_date.day, + ) + end_date = datetime.date( + date_interval.end_date.year, + date_interval.end_date.month, + date_interval.end_date.day, + ) + return start_date, end_date + + +def _compute_interval_length( + start_date: datetime.datetime, end_date: datetime.datetime +) -> int: + """Computes the number of days between `start_date` and `end_date`. + + Args: + start_date: A datetime object representing the start date. + end_date: A datetime object representing the end date. + + Returns: + The number of days between the given dates. + """ + return end_date.toordinal() - start_date.toordinal() diff --git a/schema/utils/time_record_test.py b/schema/utils/time_record_test.py new file mode 100644 index 000000000..0cc329b69 --- /dev/null +++ b/schema/utils/time_record_test.py @@ -0,0 +1,197 @@ +# Copyright 2025 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/schema/utils/time_record_test.py b/schema/utils/time_record_test.py
new file mode 100644
index 000000000..0cc329b69
--- /dev/null
+++ b/schema/utils/time_record_test.py
@@ -0,0 +1,197 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime as dt
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from mmm.v1.common import date_interval_pb2
+from schema.utils import time_record
+import numpy as np
+import pandas as pd
+
+from google.type import date_pb2
+from tensorflow.python.util.protobuf import compare
+
+
+class TimeRecordTest(parameterized.TestCase):
+
+  def test_convert_times_to_date_intervals_fewer_than_two_times(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "There must be at least 2 time points.",
+    ):
+      time_record.convert_times_to_date_intervals(
+          times=["2024-01-01"],
+      )
+
+  def test_convert_times_to_date_intervals_enforces_iso_format(self):
+    # Editorial note: the regex below was empty in the draft, which matches
+    # any ValueError; "does not match format" pins it to the strptime error.
+    with self.assertRaisesRegex(
+        ValueError,
+        "does not match format",
+    ):
+      time_record.convert_times_to_date_intervals(
+          times=["2024-01-01", "01-08-2024", "15-01-2024"],
+      )
+
+  def test_convert_times_to_date_interval_length_is_not_consistent(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Interval length between selected times must be consistent.",
+    ):
+      time_record.convert_times_to_date_intervals(
+          times=["2024-01-01", "2024-01-07", "2024-01-15"],
+      )
+
+  def test_convert_times_to_date_intervals_creates_date_intervals(self):
+    time_to_date_interval = time_record.convert_times_to_date_intervals(
+        times=["2024-01-01", "2024-01-08", "2024-01-15"],
+    )
+    expected_date_interval_1 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=1),
+        end_date=date_pb2.Date(year=2024, month=1, day=8),
+    )
+    expected_date_interval_2 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=8),
+        end_date=date_pb2.Date(year=2024, month=1, day=15),
+    )
+    expected_date_interval_3 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=15),
+        end_date=date_pb2.Date(year=2024, month=1, day=22),
+    )
+    expected_time_to_date_interval = {
+        "2024-01-01": expected_date_interval_1,
+        "2024-01-08": expected_date_interval_2,
+        "2024-01-15": expected_date_interval_3,
+    }
+    self.assertEqual(time_to_date_interval, expected_time_to_date_interval)
+
+  def test_convert_times_to_date_intervals_datetime_index_input(self):
+    time_to_date_interval = time_record.convert_times_to_date_intervals(
+        times=pd.DatetimeIndex([
+            np.datetime64("2024-01-01"),
+            np.datetime64("2024-01-08"),
+            np.datetime64("2024-01-15"),
+        ]),
+    )
+    expected_date_interval_1 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=1),
+        end_date=date_pb2.Date(year=2024, month=1, day=8),
+    )
+    expected_date_interval_2 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=8),
+        end_date=date_pb2.Date(year=2024, month=1, day=15),
+    )
+    expected_date_interval_3 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=15),
+        end_date=date_pb2.Date(year=2024, month=1, day=22),
+    )
+    expected_time_to_date_interval = {
+        "2024-01-01": expected_date_interval_1,
+        "2024-01-08": expected_date_interval_2,
+        "2024-01-15": expected_date_interval_3,
+    }
+    self.assertEqual(time_to_date_interval, expected_time_to_date_interval)
+
+  def test_convert_times_to_date_intervals_date_objects_input(self):
+    time_to_date_interval = time_record.convert_times_to_date_intervals(
+        times=[
+            dt.date(2024, 1, 1),
+            dt.date(2024, 1, 8),
+            dt.date(2024, 1, 15),
+        ],
+    )
+    expected_date_interval_1 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=1),
+        end_date=date_pb2.Date(year=2024, month=1, day=8),
+    )
+    expected_date_interval_2 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=8),
+        end_date=date_pb2.Date(year=2024, month=1, day=15),
+    )
+    expected_date_interval_3 = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=15),
+        end_date=date_pb2.Date(year=2024, month=1, day=22),
+    )
+    expected_time_to_date_interval = {
+        "2024-01-01": expected_date_interval_1,
+        "2024-01-08": expected_date_interval_2,
+        "2024-01-15": expected_date_interval_3,
+    }
+    self.assertEqual(time_to_date_interval, expected_time_to_date_interval)
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="single_day",
+          start_date=dt.datetime(year=2024, month=1, day=1),
+          end_date=dt.datetime(year=2024, month=1, day=1),
+          tag="",
+          expected=date_interval_pb2.DateInterval(
+              start_date=date_pb2.Date(year=2024, month=1, day=1),
+              end_date=date_pb2.Date(year=2024, month=1, day=1),
+          ),
+      ),
+      dict(
+          testcase_name="multiple_days",
+          start_date=dt.datetime(year=2024, month=1, day=1),
+          end_date=dt.datetime(year=2024, month=1, day=8),
+          tag="",
+          expected=date_interval_pb2.DateInterval(
+              start_date=date_pb2.Date(year=2024, month=1, day=1),
+              end_date=date_pb2.Date(year=2024, month=1, day=8),
+          ),
+      ),
+      dict(
+          testcase_name="single_day_with_tag",
+          start_date=dt.datetime(year=2024, month=1, day=1),
+          end_date=dt.datetime(year=2024, month=1, day=1),
+          tag="tag",
+          expected=date_interval_pb2.DateInterval(
+              start_date=date_pb2.Date(year=2024, month=1, day=1),
+              end_date=date_pb2.Date(year=2024, month=1, day=1),
+              tag="tag",
+          ),
+      ),
+      dict(
+          testcase_name="multiple_days_with_tag",
+          start_date=dt.datetime(year=2024, month=1, day=1),
+          end_date=dt.datetime(year=2024, month=1, day=8),
+          tag="tag",
+          expected=date_interval_pb2.DateInterval(
+              start_date=date_pb2.Date(year=2024, month=1, day=1),
+              end_date=date_pb2.Date(year=2024, month=1, day=8),
+              tag="tag",
+          ),
+      ),
+  )
+  def test_create_date_interval_pb(self, start_date, end_date, tag, expected):
+    actual = time_record.create_date_interval_pb(
+        start_date=start_date, end_date=end_date, tag=tag
+    )
+    compare.assertProtoEqual(self, actual, expected)
+
+  def test_dates_from_date_interval_proto(self):
+    date_interval = date_interval_pb2.DateInterval(
+        start_date=date_pb2.Date(year=2024, month=1, day=1),
+        end_date=date_pb2.Date(year=2025, month=2, day=8),
+    )
+    start_date, end_date = time_record.dates_from_date_interval_proto(
+        date_interval
+    )
+    self.assertEqual(start_date, dt.date(year=2024, month=1, day=1))
+    self.assertEqual(end_date, dt.date(year=2025, month=2, day=8))
+
+
+if __name__ == "__main__":
+  absltest.main()