diff --git a/tests/test_distribution_utils/test_censored_utils.py b/tests/test_distribution_utils/test_censored_utils.py
new file mode 100644
index 00000000..30b8609a
--- /dev/null
+++ b/tests/test_distribution_utils/test_censored_utils.py
@@ -0,0 +1,45 @@
+from ..utils import BaseTestClass, gen_test_data
+import pytest
+import numpy as np
+import torch
+from xgboostlss.model import XGBoostLSS
+from xgboostlss.distributions.LogNormal import CensoredLogNormal
+from xgboostlss.distributions.Weibull import CensoredWeibull
+
+
+class TestCensoredUnivariate(BaseTestClass):
+    @pytest.fixture(params=[CensoredLogNormal, CensoredWeibull])
+    def model(self, request):
+        return XGBoostLSS(request.param())
+
+    @pytest.mark.parametrize("weights", [False, True])
+    def test_censored_objective_fn_shapes_and_values(self, model, weights):
+        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
+        dmat = rest[-1]
+        grad, hess = model.dist.objective_fn(predt, dmat)
+        assert isinstance(grad, np.ndarray) and isinstance(hess, np.ndarray)
+        assert grad.shape == predt.flatten().shape and hess.shape == predt.flatten().shape
+        assert not np.isnan(grad).any() and not np.isnan(hess).any()
+        assert not np.isinf(grad).any() and not np.isinf(hess).any()
+
+    @pytest.mark.parametrize("loss_fn", ["nll", "crps"])
+    @pytest.mark.parametrize("weights", [False, True])
+    def test_censored_metric_fn_shapes_and_values(self, model, loss_fn, weights):
+        model.dist.loss_fn = loss_fn
+        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
+        dmat = rest[-1]
+        name, loss = model.dist.metric_fn(predt, dmat)
+        assert name == loss_fn and isinstance(loss, torch.Tensor)
+        assert not torch.isnan(loss).any() and not torch.isinf(loss).any()
+
+    def test_metric_fn_exact_equals_uncensored(self, model):
+        predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
+        dmat = rest[-1]
+        name_c, loss_c = model.dist.metric_fn(predt, dmat)
+        # MRO is [CensoredX, CensoredMixin, X, ...], so index 2 is the uncensored base class
+        underlying_cls = model.dist.__class__.__mro__[2]
+        base_model = XGBoostLSS(underlying_cls())
+        base_predt, base_labels, *base_rest = gen_test_data(base_model, weights=False, censored=False)
+        base_dmat = base_rest[-1]
+        name_b, loss_b = base_model.dist.metric_fn(base_predt, base_dmat)
+        assert name_c == name_b and torch.allclose(loss_c, loss_b)
\ No newline at end of file
diff --git a/tests/utils.py b/tests/utils.py
index e171ab17..7a32e1e3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -10,7 +10,7 @@
 import xgboost as xgb
 
 
-def gen_test_data(dist_class, weights: bool = False):
+def gen_test_data(dist_class, weights: bool = False, censored: bool = False):
     """
     Function that generates test data for a given distribution class.
 
@@ -20,6 +20,8 @@ def gen_test_data(dist_class, weights: bool = False):
         Distribution class.
     weights (bool):
         Whether to generate weights.
+    censored (bool):
+        Whether to generate interval-censored bounds instead of point labels.
 
     Returns:
     --------
@@ -36,6 +38,28 @@ def gen_test_data(dist_class, weights: bool = False):
         np.random.seed(123)
         predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
         labels = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
+        # Handle interval-censored data
+        if censored:
+            # base values and censoring bounds
+            labels = labels.flatten()
+            lower = labels - 0.1
+            upper = labels + 0.1
+            if weights:
+                weights_arr = np.ones_like(lower, dtype=lower.dtype)
+                dmatrix = xgb.DMatrix(predt,
+                                      label_lower_bound=lower,
+                                      label_upper_bound=upper,
+                                      weight=weights_arr)
+                dist_class.set_base_margin(dmatrix)
+                return predt, lower, upper, weights_arr, dmatrix
+            else:
+                dmatrix = xgb.DMatrix(predt,
+                                      label_lower_bound=lower,
+                                      label_upper_bound=upper)
+                dist_class.set_base_margin(dmatrix)
+                return predt, lower, upper, dmatrix
+
+        # uncensored data
         if weights:
             weights = np.ones_like(labels)
             dmatrix = xgb.DMatrix(predt, label=labels, weight=weights)
@@ -48,6 +72,7 @@ def gen_test_data(dist_class, weights: bool = False):
 
             return predt, labels, dmatrix
     else:
+        # multivariate case (censored data not supported)
         np.random.seed(123)
         predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
         labels = np.arange(0.1, 0.9, 0.1)
@@ -56,7 +81,7 @@ def gen_test_data(dist_class, weights: bool = False):
             dist_class.dist.n_targets,
             dist_class.dist.n_dist_param
         )
-        if weights:
+        if weights and not censored:
             weights = np.ones_like(labels[:, 0], dtype=labels.dtype).reshape(-1, 1)
             dmatrix = xgb.DMatrix(predt, label=labels, weight=weights)
             dist_class.set_base_margin(dmatrix)
diff --git a/xgboostlss/distributions/LogNormal.py b/xgboostlss/distributions/LogNormal.py
index b00eb694..907b7de8 100644
--- a/xgboostlss/distributions/LogNormal.py
+++ b/xgboostlss/distributions/LogNormal.py
@@ -1,5 +1,6 @@
 from torch.distributions import LogNormal as LogNormal_Torch
 from .distribution_utils import DistributionClass
+from .censored_utils import CensoredMixin
 from ..utils import *
 
 
@@ -65,3 +66,8 @@ def __init__(self,
             distribution_arg_names=list(param_dict.keys()),
             loss_fn=loss_fn
         )
+
+
+class CensoredLogNormal(CensoredMixin, LogNormal):
+    """LogNormal distribution with interval-censoring support."""
+    pass
diff --git a/xgboostlss/distributions/Weibull.py b/xgboostlss/distributions/Weibull.py
index 08006333..2edc870e 100644
--- a/xgboostlss/distributions/Weibull.py
+++ b/xgboostlss/distributions/Weibull.py
@@ -1,5 +1,6 @@
 from torch.distributions import Weibull as Weibull_Torch
 from .distribution_utils import DistributionClass
+from .censored_utils import CensoredMixin
 from ..utils import *
 
 
@@ -65,3 +66,8 @@ def __init__(self,
             distribution_arg_names=list(param_dict.keys()),
             loss_fn=loss_fn
         )
+
+
+class CensoredWeibull(CensoredMixin, Weibull):
+    """Weibull distribution with interval-censoring support."""
+    pass
\ No newline at end of file
diff --git a/xgboostlss/distributions/censored_utils.py b/xgboostlss/distributions/censored_utils.py
new file mode 100644
index 00000000..89bb6e14
--- /dev/null
+++ b/xgboostlss/distributions/censored_utils.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+import xgboost as xgb
+from typing import List, Tuple
+from .distribution_utils import DistributionClass
+
+
+class CensoredMixin(DistributionClass):
+    """
+    Mixin that adds interval-censoring support to a distribution.
+    Overrides objective_fn and metric_fn to dispatch to a censored loss when bound labels are present.
+ """ + def objective_fn(self, predt: np.ndarray, data: xgb.DMatrix): + lower = data.get_float_info("label_lower_bound") + upper = data.get_float_info("label_upper_bound") + if lower.size == 0 and upper.size == 0: + return super().objective_fn(predt, data) + if data.get_weight().size == 0: + # initialize weights as ones with correct shape + weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy() + else: + weights = data.get_weight().reshape(-1, 1) + start_values = data.get_base_margin().reshape(-1, self.n_dist_param)[0, :].tolist() + predt_list, loss = self.get_params_loss_censored( + predt, start_values, lower, upper, requires_grad=True + ) + grad, hess = self.compute_gradients_and_hessians(loss, predt_list, weights) + return grad, hess + + def metric_fn(self, predt: np.ndarray, data: xgb.DMatrix): + lower = data.get_float_info("label_lower_bound") + upper = data.get_float_info("label_upper_bound") + if lower.size == 0 and upper.size == 0: + return super().metric_fn(predt, data) + start_values = data.get_base_margin().reshape(-1, self.n_dist_param)[0, :].tolist() + _, loss = self.get_params_loss_censored( + predt, start_values, lower, upper, requires_grad=False + ) + return self.loss_fn, loss + + def get_params_loss_censored(self, + predt: np.ndarray, + start_values: List[float], + lower: np.ndarray, + upper: np.ndarray, + requires_grad: bool = False, + ) -> Tuple[List[torch.Tensor], torch.Tensor]: + """Compute loss for interval-censored data.""" + predt_arr = predt.reshape(-1, self.n_dist_param) + # replace nan/inf + mask = np.isnan(predt_arr) | np.isinf(predt_arr) + predt_arr[mask] = np.take(start_values, np.where(mask)[1]) + # convert to tensors + predt_list = [ + torch.tensor(predt_arr[:, i].reshape(-1, 1), requires_grad=requires_grad) + for i in range(self.n_dist_param) + ] + # transform parameters + params_transformed = [ + fn(predt_list[i]) for i, fn in enumerate(self.param_dict.values()) + ] + # instantiate distribution + dist = self.distribution(**dict(zip(self.distribution_arg_names, params_transformed))) + # compute cdf bounds: convert lower & upper once to tensor with correct dtype + low = torch.as_tensor(lower, dtype=params_transformed[0].dtype).reshape(-1, 1) + hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1) + cdf_low = dist.cdf(low) + cdf_hi = dist.cdf(hi) + # interval mass & loss + mass = cdf_hi - cdf_low + log_density = dist.log_prob(low) + censored_inds = low != hi + loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds]) + return predt_list, loss