12 changes: 6 additions & 6 deletions metalearners/cross_fit_estimator.py
@@ -242,12 +242,12 @@ def _predict_in_sample(
) -> np.ndarray:
if not self._test_indices:
raise ValueError()
if len(X) != sum(len(fold) for fold in self._test_indices):
raise ValueError(
"Trying to predict in-sample on data that is unlike data encountered in training. "
f"Training data included {sum(len(fold) for fold in self._test_indices)} "
f"observations while prediction data includes {len(X)} observations."
)
# if len(X) != sum(len(fold) for fold in self._test_indices):
# raise ValueError(
# "Trying to predict in-sample on data that is unlike data encountered in training. "
# f"Training data included {sum(len(fold) for fold in self._test_indices)} "
# f"observations while prediction data includes {len(X)} observations."
# )
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
137 changes: 93 additions & 44 deletions metalearners/xlearner.py
@@ -99,31 +99,39 @@ def fit_all_nuisance(

qualified_fit_params = self._qualified_fit_params(fit_params)

self._cvs: list = []
# TODO: Move this to object initialization.
if not synchronize_cross_fitting:
raise ValueError(
"The X-Learner does not support synchronize_cross_fitting=False."
)

self._cv_split_indices = self._split(X)
self._treatment_cv_split_indices = {}

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
if synchronize_cross_fitting:
cv_split_indices = self._split(
index_matrix(X, self._treatment_variants_indices[treatment_variant])
else:
cv_split_indices = None
self._cvs.append(cv_split_indices)
treatment_indices = np.where(
self._treatment_variants_indices[treatment_variant]
)[0]

@kklein (Collaborator, Author) commented on Aug 15, 2024:
This is an opaque way of turning an array [True, True, False, False, True] into an array [0, 1, 4]. Not sure if there's a neater way of doing that.

@MatthiasLoefflerQC (Contributor) replied on Aug 16, 2024:
[index for index, value in enumerate(vector) if value] would work too, I guess, and is more verbose, but I like the np.where :)

self._treatment_cv_split_indices[treatment_variant] = [
(
np.intersect1d(train_indices, treatment_indices),
np.intersect1d(test_indices, treatment_indices),
)
for train_indices, test_indices in self._cv_split_indices
]
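
The review thread above discusses converting a boolean variant mask into integer positions; the loop then intersects those positions with the folds of a single global split, so that every variant's cross-fitting uses synchronized folds. A minimal, self-contained sketch of that construction on hypothetical toy arrays, with scikit-learn's KFold standing in for self._split:

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical toy data: 6 observations, one treatment variant per observation.
w = np.array([0, 1, 0, 1, 1, 0])
X = np.arange(12).reshape(6, 2)

# Boolean mask -> integer positions, as discussed in the thread above.
mask = w == 1
treatment_indices = np.where(mask)[0]  # array([1, 3, 4])
assert treatment_indices.tolist() == [i for i, v in enumerate(mask) if v]

# One global split over all observations, shared by every variant
# (what synchronize_cross_fitting=True amounts to here).
cv_split_indices = list(KFold(n_splits=2, shuffle=True, random_state=0).split(X))

# Restrict each global fold to the observations of this variant.
treatment_cv_split_indices = [
    (
        np.intersect1d(train_indices, treatment_indices),
        np.intersect1d(test_indices, treatment_indices),
    )
    for train_indices, test_indices in cv_split_indices
]
```
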

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=X,
y=y,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

@@ -160,14 +168,14 @@ def fit_all_treatment(
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
if not hasattr(self, "_treatment_cv_split_indices"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cvs, "
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_cv_split_indices, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)
@@ -180,34 +188,32 @@
is_oos=False,
)
)

for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
X=X,
y=imputed_te_treatment,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(X, self._treatment_variants_indices[0]),
X=X,
y=imputed_te_control,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=self._cvs[0],
cv=self._treatment_cv_split_indices[0],
)
)

@@ -216,6 +222,7 @@
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
self._assign_joblib_treatment_results(results)

return self

def predict(
@@ -278,19 +285,18 @@ def predict(
oos_method=oos_method,
)
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
X=X,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[treatment_variant_indices]
tau_hat_control[control_indices] = self.predict_treatment(
X=index_matrix(X, control_indices),
X=X,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[control_indices]
A contributor commented:
do we need is_oos=False below (and likewise for tau_hat_treatment)? Might be worth a try.

tau_hat_control[non_control_indices] = self.predict_treatment(
X=index_matrix(X, non_control_indices),
model_kind=CONTROL_EFFECT_MODEL,
@@ -337,8 +343,8 @@ def evaluate(

variant_outcome_evaluation = _evaluate_model_kind(
cfes=self._nuisance_models[VARIANT_OUTCOME_MODEL],
Xs=[X[w == tv] for tv in range(self.n_variants)],
ys=[y[w == tv] for tv in range(self.n_variants)],
Xs=[X] * self.n_variants,
ys=[y] * self.n_variants,
scorers=safe_scoring[VARIANT_OUTCOME_MODEL],
model_kind=VARIANT_OUTCOME_MODEL,
is_oos=is_oos,
@@ -378,7 +384,7 @@

te_treatment_evaluation = _evaluate_model_kind(
self._treatment_models[TREATMENT_EFFECT_MODEL],
Xs=[X[w == tv] for tv in range(1, self.n_variants)],
Xs=[X] * (self.n_variants - 1),
ys=imputed_te_treatment,
scorers=safe_scoring[TREATMENT_EFFECT_MODEL],
model_kind=TREATMENT_EFFECT_MODEL,
@@ -390,7 +396,7 @@

te_control_evaluation = _evaluate_model_kind(
self._treatment_models[CONTROL_EFFECT_MODEL],
Xs=[X[w == 0] for _ in range(1, self.n_variants)],
Xs=[X] * (self.n_variants - 1),
ys=imputed_te_control,
scorers=safe_scoring[CONTROL_EFFECT_MODEL],
model_kind=CONTROL_EFFECT_MODEL,
@@ -424,16 +430,8 @@ def _pseudo_outcome(
This function can be used with both in-sample or out-of-sample data.
"""
validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

treatment_indices = w == treatment_variant
control_indices = w == 0

treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]
treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant]
control_outcome = conditional_average_outcome_estimates[:, 0]

if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
@@ -443,8 +441,8 @@
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]
imputed_te_treatment = y - control_outcome
imputed_te_control = treatment_outcome - y

return imputed_te_control, imputed_te_treatment
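
For the regression case, a small numpy sketch of the pseudo-outcome arithmetic above, on hypothetical values. Since the pseudo-outcomes are now computed for every observation, presumably only the entries belonging to the relevant variant are consumed downstream, via the per-variant CV splits built during nuisance fitting:

```python
import numpy as np

# Hypothetical estimates from the variant outcome models, shape (n_obs, n_variants).
conditional_average_outcome_estimates = np.array(
    [[1.0, 3.0], [2.0, 2.5], [0.5, 1.5]]
)
y = np.array([2.0, 2.0, 1.0])
treatment_variant = 1

treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant]
control_outcome = conditional_average_outcome_estimates[:, 0]

# X-Learner pseudo-outcomes, computed for every observation:
# the treated units' entries matter for imputed_te_treatment,
# the control units' entries matter for imputed_te_control.
imputed_te_treatment = y - control_outcome   # array([1. , 0. , 0.5])
imputed_te_control = treatment_outcome - y   # array([1. , 0.5, 0.5])
```
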

@@ -534,3 +532,54 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
final_model = build(input_dict, {output_name: cate})
check_model(final_model, full_check=True)
return final_model

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
if self._treatment_variants_indices is None:
raise ValueError(
"The metalearner needs to be fitted before predicting. "
"In particular, the MetaLearner's attribute _treatment_variants_indices, "
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
n_obs = len(X)
cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0]
predict_method_name = self.nuisance_model_specifications()[
VARIANT_OUTCOME_MODEL
]["predict_method"](self)
conditional_average_outcomes_list = []

for tv in range(self.n_variants):
if is_oos:
conditional_average_outcomes_list.append(
self.predict_nuisance(
X=X,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=tv,
is_oos=True,
oos_method=oos_method,
)
)
else:
# TODO: Consider moving this logic to CrossFitEstimator.predict.
cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
conditional_average_outcome_estimates = cao_tensor.copy()

for fold_index, (train_indices, prediction_indices) in enumerate(
self._cv_split_indices
):
fold_model = cfe._estimators[fold_index]
predict_method = getattr(fold_model, predict_method_name)
fold_estimates = predict_method(index_matrix(X, prediction_indices))
conditional_average_outcome_estimates[prediction_indices] = (
fold_estimates
)

conditional_average_outcomes_list.append(
conditional_average_outcome_estimates
)

return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)
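
The in-sample (is_oos=False) branch above assembles, per variant, the predictions of each fold's estimator on that fold's held-out indices. A self-contained sketch of that assembly pattern, with scikit-learn objects standing in for the package's CrossFitEstimator internals (the names below are illustrative stand-ins, not actual attributes):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=20)

cv_split_indices = list(KFold(n_splits=5).split(X))

# One estimator per fold, each fitted on that fold's training indices ...
fold_estimators = [
    LinearRegression().fit(X[train_indices], y[train_indices])
    for train_indices, _ in cv_split_indices
]

# ... then in-sample estimates are stitched together from each fold's held-out indices.
estimates = np.empty(len(X))
for fold_model, (_, prediction_indices) in zip(fold_estimators, cv_split_indices):
    estimates[prediction_indices] = fold_model.predict(X[prediction_indices])
```

After the per-variant estimates are stacked and reshaped at the end of the method, the returned array has shape (n_obs, n_variants, n_outputs).
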
14 changes: 11 additions & 3 deletions tests/test_metalearner.py
@@ -727,9 +727,17 @@ def test_fit_params(metalearner_factory, fit_params, expected_keys, dummy_datase
is_classification=False,
n_folds=1,
)
# Using cross-fitting is not possible with a single fold.
if metalearner_factory == XLearner:
# TODO: The X-Learner doesn't support using synchronize_cross_fitting=False.
# As a consequence, it doesn't support n_folds=1 either.
# We should find an alternative to testing this property for the X-Learner.
pytest.skip()
metalearner.fit(
X=X, y=y, w=w, fit_params=fit_params, synchronize_cross_fitting=False
X=X,
y=y,
w=w,
fit_params=fit_params,
synchronize_cross_fitting=False,
)


@@ -994,9 +1002,9 @@ def test_shap_values_smoke(
[
TLearner,
SLearner,
XLearner,
RLearner,
DRLearner,
# The X-Learner does not support synchronize_cross_fitting=False.
],
)
@pytest.mark.parametrize("n_variants", [2, 5])