@@ -383,3 +383,52 @@ def test_rkink_epsilon_check():
383
383
kink_point = kink ,
384
384
epsilon = - 1 ,
385
385
)
386
+
387
+
388
+ # RegressionDiscontinuity
389
+
390
+
391
+ def setup_regression_discontinuity_data (threshold = 0.5 ):
392
+ """Create data for a regression discontinuity test."""
393
+ np .random .seed (42 )
394
+ x = np .random .uniform (0 , 1 , 100 )
395
+ treated = np .where (x > threshold , 1 , 0 )
396
+ y = 2 * x + treated + np .random .normal (0 , 1 , 100 )
397
+ return pd .DataFrame ({"x" : x , "treated" : treated , "y" : y })
398
+
399
+
400
+ def test_regression_discontinuity_int_treatment ():
401
+ """Test that RegressionDiscontinuity works with integer treatment variables."""
402
+ threshold = 0.5
403
+ df = setup_regression_discontinuity_data (threshold )
404
+ assert df ["treated" ].dtype == np .int64 # Ensure treatment is int
405
+
406
+ # This should work now with our fix
407
+ result = cp .RegressionDiscontinuity (
408
+ df ,
409
+ formula = "y ~ 1 + x + treated + x:treated" ,
410
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
411
+ treatment_threshold = threshold ,
412
+ )
413
+
414
+ # Check that the treatment variable was converted to bool
415
+ assert result .data ["treated" ].dtype == bool
416
+
417
+
418
+ def test_regression_discontinuity_bool_treatment ():
419
+ """Test that RegressionDiscontinuity works with boolean treatment variables."""
420
+ threshold = 0.5
421
+ df = setup_regression_discontinuity_data (threshold )
422
+ df ["treated" ] = df ["treated" ].astype (bool ) # Convert to bool
423
+ assert df ["treated" ].dtype == bool # Ensure treatment is bool
424
+
425
+ # This should work as before
426
+ result = cp .RegressionDiscontinuity (
427
+ df ,
428
+ formula = "y ~ 1 + x + treated + x:treated" ,
429
+ model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
430
+ treatment_threshold = threshold ,
431
+ )
432
+
433
+ # Check that the treatment variable is still bool
434
+ assert result .data ["treated" ].dtype == bool
0 commit comments