Skip to content

Commit 3d3f482

Browse files
authored
Merge pull request #506 from pymc-labs/regression-discontinuity-treatment-type
Handle integer treated variable in `RegressionDiscontinuity` experiment class
2 parents d4ba5be + e933420 commit 3d3f482

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

causalpy/experiments/regression_discontinuity.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ def input_validation(self):
210210
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
211211
)
212212

213+
# Convert integer treated variable to boolean if needed
214+
if self.data["treated"].dtype in ["int64", "int32"]:
215+
# Make a copy to avoid SettingWithCopyWarning
216+
self.data = self.data.copy()
217+
self.data["treated"] = self.data["treated"].astype(bool)
218+
213219
def _is_treated(self, x):
214220
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
215221

causalpy/tests/test_input_validation.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,52 @@ def test_rkink_epsilon_check():
383383
kink_point=kink,
384384
epsilon=-1,
385385
)
386+
387+
388+
# RegressionDiscontinuity
389+
390+
391+
def setup_regression_discontinuity_data(threshold=0.5):
392+
"""Create data for a regression discontinuity test."""
393+
np.random.seed(42)
394+
x = np.random.uniform(0, 1, 100)
395+
treated = np.where(x > threshold, 1, 0)
396+
y = 2 * x + treated + np.random.normal(0, 1, 100)
397+
return pd.DataFrame({"x": x, "treated": treated, "y": y})
398+
399+
400+
def test_regression_discontinuity_int_treatment():
401+
"""Test that RegressionDiscontinuity works with integer treatment variables."""
402+
threshold = 0.5
403+
df = setup_regression_discontinuity_data(threshold)
404+
assert df["treated"].dtype == np.int64 # Ensure treatment is int
405+
406+
# This should work now with our fix
407+
result = cp.RegressionDiscontinuity(
408+
df,
409+
formula="y ~ 1 + x + treated + x:treated",
410+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
411+
treatment_threshold=threshold,
412+
)
413+
414+
# Check that the treatment variable was converted to bool
415+
assert result.data["treated"].dtype == bool
416+
417+
418+
def test_regression_discontinuity_bool_treatment():
419+
"""Test that RegressionDiscontinuity works with boolean treatment variables."""
420+
threshold = 0.5
421+
df = setup_regression_discontinuity_data(threshold)
422+
df["treated"] = df["treated"].astype(bool) # Convert to bool
423+
assert df["treated"].dtype == bool # Ensure treatment is bool
424+
425+
# This should work as before
426+
result = cp.RegressionDiscontinuity(
427+
df,
428+
formula="y ~ 1 + x + treated + x:treated",
429+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
430+
treatment_threshold=threshold,
431+
)
432+
433+
# Check that the treatment variable is still bool
434+
assert result.data["treated"].dtype == bool

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)