diff --git a/causalpy/experiments/regression_discontinuity.py b/causalpy/experiments/regression_discontinuity.py index 1afc5c1a..ebf3b103 100644 --- a/causalpy/experiments/regression_discontinuity.py +++ b/causalpy/experiments/regression_discontinuity.py @@ -190,6 +190,8 @@ def input_validation(self): raise DataException( """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 ) + if not self.data['treated'].dtype == 'bool': + raise ValueError("The 'treated' column must be of type bool.Please convert your data accordingly.") def _is_treated(self, x): """Returns ``True`` if `x` is greater than or equal to the treatment threshold. diff --git a/causalpy/experiments/test_treated_column_valid.py b/causalpy/experiments/test_treated_column_valid.py new file mode 100644 index 00000000..99397e1e --- /dev/null +++ b/causalpy/experiments/test_treated_column_valid.py @@ -0,0 +1,20 @@ +import pandas as pd +import pytest + +def _check_treated_column_validity(df, treated_col_name): + treated_col = df[treated_col_name] + if not pd.api.types.is_bool_dtype(treated_col): + raise ValueError(f"The '{treated_col_name}' column must be of boolean dtype (True/False).") + +def test_treated_column_with_integers(): + df = pd.DataFrame({"treated": [0, 1, 0, 1]}) + with pytest.raises(ValueError, match="treated.*must be of boolean dtype"): + _check_treated_column_validity(df, "treated") + +def test_treated_column_with_booleans(): + df = pd.DataFrame({"treated": [True, False, True, False]}) + try: + _check_treated_column_validity(df, "treated") + except ValueError: + pytest.fail("Unexpected ValueError raised") +