
Mitigating treatment-related bias in text embeddings using regression-based residualization.


torinriley/Causal-Embedding-Correction


Techniques to Reduce Treatment Leakage in Text Embeddings

Problem Overview

Treatment Leakage

Treatment leakage arises when representations (e.g., text embeddings) contain information about the treatment variable. This can distort causal analysis, making it difficult to disentangle the causal effect of treatment from confounding factors.

Embedding-Based Models

Text embeddings, such as those generated by transformer models (e.g., BERT), provide high-dimensional vector representations of textual data. However, their high capacity can also encode unintended correlations with treatment variables.


Methodology

1. Extracting Text Embeddings

We use the pretrained models BERT-base-uncased, RoBERTa-base, and DistilBERT-base-uncased to extract embeddings for textual data. The embeddings represent the semantic meaning of the input text in a high-dimensional space.

  • Step: The last_hidden_state of the model is averaged across tokens to produce a single embedding vector for each text instance.
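As a sketch, the mean pooling over `last_hidden_state` can be written as follows. The pooling masks out padding tokens before averaging; the commented usage shows the usual Hugging Face `transformers` calls with one of the models above.

```python
import torch

def mean_pool(last_hidden_state, attention_mask):
    """Average token vectors into one embedding, ignoring padding tokens."""
    mask = attention_mask.unsqueeze(-1).float()       # (batch, tokens, 1)
    summed = (last_hidden_state * mask).sum(dim=1)    # sum over real tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)          # number of real tokens
    return summed / counts

# Typical usage with a pretrained encoder (e.g. bert-base-uncased):
#   from transformers import AutoTokenizer, AutoModel
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModel.from_pretrained("bert-base-uncased")
#   batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
#   with torch.no_grad():
#       emb = mean_pool(model(**batch).last_hidden_state, batch["attention_mask"])
```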

2. High-Dimensional Treatments

To demonstrate the methodology, a synthetic high-dimensional treatment variable was generated with 10 independent features. These features simulate a realistic, complex treatment structure.
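A minimal way to generate such a synthetic treatment (the sample size and distribution here are illustrative assumptions, not values taken from the project):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_treatment_features = 500, 10
# 10 independent continuous treatment features, one row per text instance.
T = rng.normal(size=(n_samples, n_treatment_features))
```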


3. Embedding Decomposition Using Random Forest Regression

To isolate and remove treatment-related information from embeddings, we use Random Forest Regression.

Steps:

  1. Train Regressor:

    • A Random Forest model is trained with the high-dimensional treatment features as input and the original embeddings as the target.
    • The model captures nonlinear and high-dimensional relationships between the treatment and embeddings.
  2. Predict Treatment Components:

    • The trained model predicts the treatment-related components of the embeddings.
  3. Partial Residualization:

    • The predicted treatment components are scaled by a parameter α and subtracted from the original embeddings.
    • This method balances treatment de-biasing and the retention of meaningful information.
  4. Propensity Scoring:

    • To further enhance causal validity and address confounding bias, we integrate propensity scores into the methodology.

      Steps:

    • Fit Propensity Model: A logistic regression model estimates the propensity scores, which represent the probability of receiving treatment given observed covariates. These scores are computed for each instance in the dataset based on the high-dimensional treatment features.

    • Evaluate Propensity Scores:

      • The AUC (Area Under the Curve) of the propensity model is calculated to validate its effectiveness.
      • High AUC values indicate that the model reliably discriminates treatment assignment from the observed covariates.
    • Compute Inverse Propensity Weights: The propensity scores are used to create inverse propensity weights. These weights adjust for covariate imbalances between treatment and control groups, mitigating confounding bias.

  5. Sensitivity Analysis:

    • To assess the robustness of treatment effect estimates to unobserved confounders, sensitivity analysis simulates potential unmeasured factors and evaluates their impact on propensity scores.
  6. Positivity Check:

    • The methodology ensures that the positivity assumption is not violated by checking for extreme propensity scores (close to 0 or 1). Observations with extreme scores are flagged to mitigate assumption violations.
  7. Balancing Diagnostics:

    • Standardized Mean Differences (SMD) before and after weighting are visualized to assess covariate balance between treatment groups.
  8. Uncertainty Quantification:

    • Bootstrapping is used to estimate confidence intervals for treatment effect estimates, providing robust uncertainty measures.
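The residualization and weighting steps above can be sketched end to end with scikit-learn. The synthetic data, the α value, and the binarized treatment indicator used for the propensity model are all illustrative assumptions, not the project's actual settings:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n, d_treat, d_emb = 300, 10, 16
T = rng.normal(size=(n, d_treat))                 # treatment features
W = rng.normal(size=(d_treat, d_emb))
E = rng.normal(size=(n, d_emb)) + 0.8 * T @ W     # embeddings with leakage

# Steps 1-3: regress embeddings on treatment, then partially residualize.
rf = RandomForestRegressor(n_estimators=50, random_state=0)
rf.fit(T, E)
alpha = 0.9                                       # de-biasing strength
E_adj = E - alpha * rf.predict(T)

# Step 4: propensity scores from logistic regression on a binary indicator.
t_bin = (T[:, 0] > 0).astype(int)                 # illustrative assignment
ps = LogisticRegression().fit(T, t_bin).predict_proba(T)[:, 1]
auc = roc_auc_score(t_bin, ps)

# Inverse propensity weights, clipped to guard the positivity assumption.
ps_clip = np.clip(ps, 0.01, 0.99)
ipw = np.where(t_bin == 1, 1.0 / ps_clip, 1.0 / (1.0 - ps_clip))
```

Clipping the scores to [0.01, 0.99] is one simple way to flag and bound the extreme propensities that the positivity check targets.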

4. Validation and Visualizations

Correlation Analysis

We calculate the mean absolute correlation between the embedding dimensions and the treatment dimensions to assess treatment leakage:

  • Original Embeddings: Higher correlations with treatment.
  • Adjusted Embeddings: Reduced correlations, confirming the removal of treatment-related signals.
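One way to compute this leakage metric (a sketch; the helper name is ours):

```python
import numpy as np

def mean_abs_corr(E, T):
    """Mean absolute Pearson correlation between every embedding
    dimension and every treatment dimension."""
    d_emb = E.shape[1]
    # np.corrcoef treats rows as variables, so transpose both matrices;
    # the off-diagonal block holds embedding-vs-treatment correlations.
    cross = np.corrcoef(E.T, T.T)[:d_emb, d_emb:]
    return float(np.abs(cross).mean())
```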

Variance Comparison

A comparison of total variance in the embeddings before and after partial residualization demonstrates the extent to which treatment-related variance is removed while preserving overall variability.

Embedding Scatter Plots

t-SNE scatter plots visualize the structure of embeddings:

  • Original Embeddings: Show strong clustering based on treatment, indicating leakage.
  • Adjusted Embeddings: Scatter with reduced clustering, indicating treatment-agnostic embeddings.

Distribution of Residuals

Histograms of residuals confirm the removal of treatment-related components, as residuals are tightly centered around zero.

Balancing Diagnostics

Bar plots of standardized mean differences (SMD) before and after weighting illustrate the improved covariate balance achieved through propensity score weighting.
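The SMD itself can be computed per covariate; a minimal weighted version (our own helper, using the pooled-variance convention):

```python
import numpy as np

def smd(x, treated, weights=None):
    """Standardized mean difference of covariate x between treatment groups,
    optionally weighted (e.g. by inverse propensity weights)."""
    if weights is None:
        weights = np.ones_like(x, dtype=float)
    m1 = np.average(x[treated == 1], weights=weights[treated == 1])
    m0 = np.average(x[treated == 0], weights=weights[treated == 0])
    v1 = np.average((x[treated == 1] - m1) ** 2, weights=weights[treated == 1])
    v0 = np.average((x[treated == 0] - m0) ** 2, weights=weights[treated == 0])
    return (m1 - m0) / np.sqrt((v1 + v0) / 2.0)
```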

Sensitivity Analysis

Simulations with unobserved confounders quantify the robustness of propensity score estimates to unmeasured variables.

Mutual Information Analysis

We measure the mutual information between embeddings and:

  • Treatment Features: Quantifies treatment leakage in original embeddings.
  • Outcome: Validates the extent to which embeddings capture predictive information about the outcome.
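A sketch of such an MI estimate using scikit-learn's nearest-neighbor estimator, averaging over embedding dimensions against a single one-dimensional target (the helper name is ours):

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

def mean_mutual_info(E, target, random_state=0):
    """Average estimated mutual information (in nats) between each
    embedding dimension and a one-dimensional continuous target."""
    mi = mutual_info_regression(E, target, random_state=random_state)
    return float(mi.mean())
```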

Results

  • Correlation Metrics:

    • Original Embeddings: Moderate correlation with treatment variables.
    • Adjusted Embeddings: Significantly reduced correlations, validating the effectiveness of partial residualization.
  • Variance Analysis:

    • The adjusted embeddings retain a meaningful proportion of variance while eliminating treatment-related components.
  • Visualizations:

    • Scatter plots of original and adjusted embeddings and the distribution of residuals provide qualitative and quantitative evidence of the methodology's success.
  • Positivity Check:

    • No extreme propensity scores detected, indicating that the positivity assumption holds.
  • Mutual Information Results:

    • Mutual Information with Treatment: 0.449
    • Mutual Information with Outcome: 0.418
  • Estimated Causal Effect:

    • Causal Effect Estimate: 0.0002 (95% CI: [0.0002, 0.0003])

Embedding Analysis Visualizations

Variance and Correlation Comparison

(Figure: side-by-side variance comparison and correlation comparison plots.)

Embedding Scatter Plots

(Figure: t-SNE scatter plots of the original and adjusted embeddings.)

Conclusion

This project demonstrates a robust methodology to mitigate treatment leakage in text embeddings using the following:

  1. Random Forest Regression to model and remove treatment-related components.
  2. Partial Residualization to balance treatment de-biasing with the retention of meaningful data relationships.
  3. Propensity Score Integration to address confounding bias and improve covariate balance.
  4. Sensitivity Analysis and Diagnostics to validate assumptions and ensure robust causal inference.
  5. Uncertainty Quantification using bootstrapping to provide confidence intervals for treatment effect estimates.

By ensuring embeddings are treatment-agnostic but not treatment-blind, this approach enhances the reliability of causal inference models, enabling more accurate estimation of causal effects.

For complete results: Results Summary

See a disclaimer here: Disclaimer
