diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 37204052..ec92511b 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -19,8 +19,8 @@ import numpy as np import pandas as pd import seaborn as sns +from formulae import design_matrices from matplotlib import pyplot as plt -from patsy import build_design_matrices, dmatrices from sklearn.base import RegressorMixin from causalpy.custom_exceptions import ( @@ -91,16 +91,18 @@ def __init__( self.data = data self.expt_type = "Difference in Differences" self.formula = formula + self.rhs_formula = formula.split("~", 1)[1].strip() self.time_variable_name = time_variable_name self.group_variable_name = group_variable_name self.input_validation() - y, X = dmatrices(formula, self.data) - self._y_design_info = y.design_info - self._x_design_info = X.design_info - self.labels = X.design_info.column_names - self.y, self.X = np.asarray(y), np.asarray(X) - self.outcome_variable_name = y.design_info.column_names[0] + dm = design_matrices(self.formula, self.data) + self.labels = list(dm.common.terms.keys()) + self.y, self.X = ( + np.asarray(dm.response.design_matrix).reshape(-1, 1), + np.asarray(dm.common.design_matrix), + ) + self.outcome_variable_name = dm.response.name # fit model if isinstance(self.model, PyMCModel): @@ -125,8 +127,8 @@ def __init__( ) if self.x_pred_control.empty: raise ValueError("x_pred_control is empty") - (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control) - self.y_pred_control = self.model.predict(np.asarray(new_x)) + new_x = np.array(design_matrices(self.rhs_formula, self.x_pred_control).common) + self.y_pred_control = self.model.predict(new_x) # predicted outcome for treatment group self.x_pred_treatment = ( @@ -142,8 +144,10 @@ def __init__( ) if self.x_pred_treatment.empty: raise ValueError("x_pred_treatment is empty") - (new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment) - self.y_pred_treatment = self.model.predict(np.asarray(new_x)) + new_x = np.array( + design_matrices(self.rhs_formula, self.x_pred_treatment).common + ) + self.y_pred_treatment = self.model.predict(new_x) # predicted outcome for counterfactual. This is given by removing the influence # of the interaction term between the group and the post_treatment variable @@ -162,15 +166,15 @@ def __init__( ) if self.x_pred_counterfactual.empty: raise ValueError("x_pred_counterfactual is empty") - (new_x,) = build_design_matrices( - [self._x_design_info], self.x_pred_counterfactual, return_type="dataframe" + new_x = np.array( + design_matrices(self.rhs_formula, self.x_pred_counterfactual).common ) # INTERVENTION: set the interaction term between the group and the # post_treatment variable to zero. This is the counterfactual. for i, label in enumerate(self.labels): if "post_treatment" in label and self.group_variable_name in label: - new_x.iloc[:, i] = 0 - self.y_pred_counterfactual = self.model.predict(np.asarray(new_x)) + new_x[:, i] = 0 + self.y_pred_counterfactual = self.model.predict(new_x) # calculate causal impact if isinstance(self.model, PyMCModel): diff --git a/pyproject.toml b/pyproject.toml index 99c4a651..fba5f6a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "numpy", "pandas", "patsy", + "formulae", "pymc>=5.15.1", "scikit-learn>=1", "scipy",