Treatment leakage arises when representations (e.g., text embeddings) contain information about the treatment variable. This can distort causal analysis, making it difficult to disentangle the causal effect of treatment from confounding factors.
Text embeddings, such as those generated by transformer models (e.g., BERT), provide high-dimensional vector representations of textual data. However, their high capacity may encode unintended correlations with treatment variables.
We use the pretrained models BERT-base-uncased, RoBERTa-base, and DistilBERT-base-uncased to extract embeddings for textual data. The embeddings represent the semantic meaning of the input text in a high-dimensional space.
- **Step**: The `last_hidden_state` of the model is averaged across tokens (mean pooling) to produce a single embedding vector for each text instance.
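A minimal sketch of this extraction step, assuming the Hugging Face `transformers` library; `bert-base-uncased` is shown, but any of the listed checkpoints can be substituted. Averaging only over non-padding tokens is one reasonable way to implement the mean pooling described above.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Any of the listed checkpoints works here; bert-base-uncased is shown.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

def embed(texts):
    """Mean-pool last_hidden_state over non-padding tokens -> one vector per text."""
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**enc)
    hidden = out.last_hidden_state                      # (batch, seq_len, dim)
    mask = enc["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()                    # (batch, dim)

# In practice the full corpus is embedded; two toy inputs are shown here.
example_vectors = embed(["an example document", "another short text"])
print(example_vectors.shape)  # (2, 768) for bert-base-uncased
```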
To demonstrate the methodology, a synthetic high-dimensional treatment variable was generated with 10 independent features. These features simulate a realistic, complex treatment structure.
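A minimal sketch of the synthetic treatment, assuming standard-normal independent features; the sample size and the variable names (`treatment`, `n_samples`) are illustrative, and the rows are assumed to align one-to-one with the embedded text instances.

```python
import numpy as np

rng = np.random.default_rng(42)

n_samples = 1000          # illustrative corpus size (one row per text instance)
n_treatment_dims = 10     # 10 independent treatment features, as described above

# Independent features standing in for a complex, high-dimensional treatment.
treatment = rng.normal(size=(n_samples, n_treatment_dims))
```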
To isolate and remove treatment-related information from embeddings, we use Random Forest Regression.
- **Train Regressor**:
  - A Random Forest model is trained with the high-dimensional treatment features as input and the original embeddings as the target.
  - The model captures nonlinear and high-dimensional relationships between the treatment and embeddings.
- **Predict Treatment Components**:
  - The trained model predicts the treatment-related components of the embeddings.
- **Partial Residualization**:
  - The predicted treatment components are scaled by a parameter α and subtracted from the original embeddings.
  - This method balances treatment de-biasing and the retention of meaningful information.
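A sketch of these three steps, assuming scikit-learn's multi-output `RandomForestRegressor`. Here `embeddings` is assumed to be an `(n_samples, hidden_dim)` array produced by the extraction step and row-aligned with `treatment`; the value of α is illustrative.

```python
from sklearn.ensemble import RandomForestRegressor

# `embeddings` is assumed to hold one embedding row per instance, aligned with `treatment`.
rf = RandomForestRegressor(n_estimators=200, random_state=0)
rf.fit(treatment, embeddings)                 # treatment features -> embeddings

# Treatment-related components of each embedding.
treatment_components = rf.predict(treatment)

# Partial residualization: remove only a fraction alpha of the predicted components.
alpha = 0.5  # illustrative; alpha=0 keeps the originals, alpha=1 subtracts the full prediction
adjusted_embeddings = embeddings - alpha * treatment_components
```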
**Propensity Scoring**: To further enhance causal validity and address confounding bias, we integrate propensity scores into the methodology.

Steps (a code sketch of these steps follows the list):

- **Estimate Propensity Scores**:
  - A logistic regression model estimates the propensity scores, which represent the probability of receiving treatment given the observed covariates. These scores are computed for each instance in the dataset based on the high-dimensional treatment features.
- **Evaluate Propensity Scores**:
  - The AUC (Area Under the Curve) of the propensity model is calculated to validate its effectiveness.
  - Higher AUC values indicate that the model discriminates well between treated and control instances using the observed covariates.
- **Compute Inverse Propensity Weights**:
  - The propensity scores are used to create inverse propensity weights. These weights adjust for covariate imbalances between treatment and control groups, mitigating confounding bias.
- **Sensitivity Analysis**:
  - To assess the robustness of treatment effect estimates to unobserved confounders, sensitivity analysis simulates potential unmeasured factors and evaluates their impact on propensity scores.
- **Positivity Check**:
  - The methodology checks for extreme propensity scores (close to 0 or 1) to detect violations of the positivity assumption. Observations with extreme scores are flagged.
- **Balancing Diagnostics**:
  - Standardized Mean Differences (SMD) before and after weighting are visualized to assess covariate balance between treatment groups (see the diagnostics sketch after this list).
- **Uncertainty Quantification**:
  - Bootstrapping is used to estimate confidence intervals for treatment effect estimates, providing robust uncertainty measures.
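A sketch of the propensity steps (estimation, AUC evaluation, inverse weighting, and the positivity check), continuing from the earlier sketches. The binary treatment indicator `treated`, the assignment mechanism below, and the flagging thresholds are assumptions introduced for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

covariates = treatment  # observed covariates / treatment features from the synthetic setup

# Hypothetical binary treatment assignment: depends on the covariates but stays
# probabilistic, so propensity scores are kept away from 0 and 1.
rng_t = np.random.default_rng(7)
logits = 0.8 * covariates[:, 0] + 0.5 * covariates[:, 1]
treated = rng_t.binomial(1, 1.0 / (1.0 + np.exp(-logits)))

# 1) Estimate propensity scores with logistic regression.
ps_model = LogisticRegression(max_iter=1000).fit(covariates, treated)
propensity = ps_model.predict_proba(covariates)[:, 1]

# 2) Evaluate the propensity model with AUC.
auc = roc_auc_score(treated, propensity)
print(f"Propensity model AUC: {auc:.3f}")

# 3) Inverse propensity weights: 1/e(x) for treated units, 1/(1 - e(x)) for controls.
weights = np.where(treated == 1, 1.0 / propensity, 1.0 / (1.0 - propensity))

# 4) Positivity check: flag extreme scores (thresholds are illustrative).
extreme = (propensity < 0.01) | (propensity > 0.99)
print(f"Observations flagged for extreme propensity scores: {extreme.sum()}")
```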
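A companion sketch for the balancing diagnostics and uncertainty quantification. The SMD convention (pooled unweighted standard deviation), the synthetic `outcome`, and the weighted difference-in-means estimator used inside the bootstrap are all illustrative assumptions rather than the project's exact choices.

```python
import numpy as np

def smd(x, t, w=None):
    """Standardized mean difference of one covariate between treated and control,
    optionally using observation weights for the group means."""
    w = np.ones_like(x) if w is None else w
    m1 = np.average(x[t == 1], weights=w[t == 1])
    m0 = np.average(x[t == 0], weights=w[t == 0])
    pooled_sd = np.sqrt((x[t == 1].var() + x[t == 0].var()) / 2)  # one common convention
    return (m1 - m0) / pooled_sd

smd_before = [smd(covariates[:, j], treated) for j in range(covariates.shape[1])]
smd_after = [smd(covariates[:, j], treated, weights) for j in range(covariates.shape[1])]

# Hypothetical synthetic outcome, introduced only so the bootstrap has something to estimate.
rng_y = np.random.default_rng(3)
outcome = covariates[:, 0] + 0.2 * treated + rng_y.normal(scale=0.5, size=len(treated))

def weighted_effect(idx):
    """Inverse-propensity-weighted difference in mean outcomes on a bootstrap resample."""
    t, y, w = treated[idx], outcome[idx], weights[idx]
    return (np.average(y[t == 1], weights=w[t == 1])
            - np.average(y[t == 0], weights=w[t == 0]))

rng_boot = np.random.default_rng(1)
n = len(treated)
boot = [weighted_effect(rng_boot.integers(0, n, size=n)) for _ in range(1000)]
ci_low, ci_high = np.percentile(boot, [2.5, 97.5])
print(f"Bootstrap 95% CI for the treatment effect: [{ci_low:.4f}, {ci_high:.4f}]")
```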
We calculate the mean absolute correlation between the embedding dimensions and the treatment dimensions to assess treatment leakage:
- Original Embeddings: Higher correlations with treatment.
- Adjusted Embeddings: Reduced correlations, confirming the removal of treatment-related signals.
A comparison of total variance in the embeddings before and after partial residualization demonstrates the extent to which treatment-related variance is removed while preserving overall variability (both checks are sketched below).
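A sketch of both leakage checks, reusing the arrays from the earlier sketches; the mean absolute Pearson correlation and the total-variance ratio are straightforward but specific implementation choices.

```python
import numpy as np

def mean_abs_corr(emb, treat):
    """Mean absolute Pearson correlation between every embedding dimension
    and every treatment dimension."""
    d = emb.shape[1]
    corr = np.corrcoef(emb.T, treat.T)[:d, d:]  # cross-correlation block
    return float(np.abs(corr).mean())

print("Mean |corr|, original:", mean_abs_corr(embeddings, treatment))
print("Mean |corr|, adjusted:", mean_abs_corr(adjusted_embeddings, treatment))

# Total variance retained after partial residualization.
var_ratio = adjusted_embeddings.var(axis=0).sum() / embeddings.var(axis=0).sum()
print(f"Share of total variance retained: {var_ratio:.2%}")
```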
t-SNE scatter plots visualize the structure of the embeddings:
- Original Embeddings: Show strong clustering based on treatment, indicating leakage.
- Adjusted Embeddings: Show reduced treatment-based clustering, indicating treatment-agnostic embeddings.
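A plotting sketch for these projections, assuming scikit-learn's `TSNE` and matplotlib; coloring points by the assumed binary indicator `treated` from the earlier sketch is one way to make treatment-based clustering visible.

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def tsne_panel(emb, title, ax):
    """Project embeddings to 2-D with t-SNE and color points by treatment status."""
    coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(emb)
    ax.scatter(coords[:, 0], coords[:, 1], c=treated, cmap="coolwarm", s=10)
    ax.set_title(title)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
tsne_panel(embeddings, "Original embeddings", axes[0])
tsne_panel(adjusted_embeddings, "Adjusted embeddings", axes[1])
plt.show()
```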
Histograms of residuals confirm the removal of treatment-related components, as residuals are tightly centered around zero.
Bar plots of standardized mean differences (SMD) before and after weighting illustrate the improved covariate balance achieved through propensity score weighting.
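A minimal plotting sketch for these two diagnostics, reusing the residuals implied by partial residualization and the SMD values computed in the earlier sketch; the figure layout and the 0.1 balance threshold are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

# Components removed from the embeddings by partial residualization.
residuals = (embeddings - adjusted_embeddings).ravel()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Histogram of the removed components.
ax1.hist(residuals, bins=50)
ax1.set_title("Removed treatment-related components")

# |SMD| per covariate, before vs. after inverse propensity weighting.
x = np.arange(len(smd_before))
ax2.bar(x - 0.2, np.abs(smd_before), width=0.4, label="before weighting")
ax2.bar(x + 0.2, np.abs(smd_after), width=0.4, label="after weighting")
ax2.axhline(0.1, linestyle="--", color="gray")  # common balance rule of thumb
ax2.set_xlabel("covariate index")
ax2.set_ylabel("|SMD|")
ax2.legend()

plt.tight_layout()
plt.show()
```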
Simulations with unobserved confounders quantify the robustness of propensity score estimates to unmeasured variables.
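A sketch of one such simulation, assuming the unmeasured confounder is generated as noise correlated with the treatment indicator; the confounder strengths and the summary statistic (mean absolute shift in propensity scores) are illustrative choices, not the project's exact procedure.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def propensity_shift(strength, seed=0):
    """Add a simulated unmeasured confounder of a given strength, refit the
    propensity model, and return the mean absolute shift in propensity scores."""
    rng_u = np.random.default_rng(seed)
    u = strength * treated + rng_u.normal(size=len(treated))  # confounder correlated with treatment
    cov_u = np.column_stack([covariates, u])
    ps_u = LogisticRegression(max_iter=1000).fit(cov_u, treated).predict_proba(cov_u)[:, 1]
    return np.abs(ps_u - propensity).mean()

for strength in (0.0, 0.5, 1.0):  # illustrative confounder strengths
    print(f"strength={strength}: mean |shift| = {propensity_shift(strength):.3f}")
```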
We measure the mutual information between embeddings and:
- Treatment Features: Quantifies treatment leakage in original embeddings.
- Outcome: Validates the extent to which embeddings capture predictive information about the outcome.
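A sketch of this measurement, assuming scikit-learn's `mutual_info_regression` and the arrays from the earlier sketches; averaging the per-dimension estimates into a single score is an aggregation choice made here for illustration, not necessarily the exact statistic reported below.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

def mean_mutual_info(emb, targets):
    """Average mutual information between every embedding dimension and every target column."""
    per_target = [mutual_info_regression(emb, targets[:, j], random_state=0).mean()
                  for j in range(targets.shape[1])]
    return float(np.mean(per_target))

mi_treatment = mean_mutual_info(adjusted_embeddings, treatment)
mi_outcome = mutual_info_regression(adjusted_embeddings, outcome, random_state=0).mean()
print(f"Mutual information with treatment: {mi_treatment:.3f}")
print(f"Mutual information with outcome:   {mi_outcome:.3f}")
```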
- **Correlation Metrics**:
  - Original Embeddings: Moderate correlation with treatment variables.
  - Adjusted Embeddings: Significantly reduced correlations, validating the effectiveness of partial residualization.
- **Variance Analysis**:
  - The adjusted embeddings retain a meaningful proportion of variance while eliminating treatment-related components.
- **Visualizations**:
  - Scatter plots of original and adjusted embeddings and the distribution of residuals provide qualitative and quantitative evidence of the methodology's success.
- **Positivity Check**:
  - No extreme propensity scores detected, indicating that the positivity assumption holds.
- **Mutual Information Results**:
  - Mutual information with treatment: 0.449
  - Mutual information with outcome: 0.418
- **Estimated Causal Effect**:
  - Causal effect estimate: 0.0002, with 95% CI [0.0002, 0.0003]
Figures: variance comparison and correlation comparison plots, and original vs. adjusted embedding visualizations.
This project demonstrates a robust methodology to mitigate treatment leakage in text embeddings using the following:
- Random Forest Regression to model and remove treatment-related components.
- Partial Residualization to balance treatment de-biasing with the retention of meaningful data relationships.
- Propensity Score Integration to address confounding bias and improve covariate balance.
- Sensitivity Analysis and Diagnostics to validate assumptions and ensure robust causal inference.
- Uncertainty Quantification using bootstrapping to provide confidence intervals for treatment effect estimates.
By ensuring embeddings are treatment-agnostic but not treatment-blind, this approach enhances the reliability of causal inference models, enabling more accurate estimation of causal effects.
For complete results: Results Summary
See a disclaimer here: Disclaimer