44 changes: 44 additions & 0 deletions tests/test_distribution_utils/test_censored_utils.py
@@ -0,0 +1,44 @@
from ..utils import BaseTestClass
import pytest
import numpy as np
import torch
from xgboostlss.model import XGBoostLSS
from xgboostlss.distributions.LogNormal import CensoredLogNormal
from xgboostlss.distributions.Weibull import CensoredWeibull
from tests.utils import gen_test_data

class TestCensoredUnivariate(BaseTestClass):
    @pytest.fixture(params=[CensoredLogNormal, CensoredWeibull])
    def model(self, request):
        return XGBoostLSS(request.param())

    @pytest.mark.parametrize("weights", [False, True])
    def test_censored_objective_fn_shapes_and_values(self, model, weights):
        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
        dmat = rest[-1]
        grad, hess = model.dist.objective_fn(predt, dmat)
        assert isinstance(grad, np.ndarray) and isinstance(hess, np.ndarray)
        assert grad.shape == predt.flatten().shape and hess.shape == predt.flatten().shape
        assert not np.isnan(grad).any() and not np.isnan(hess).any()
        assert not np.isinf(grad).any() and not np.isinf(hess).any()

    @pytest.mark.parametrize("loss_fn", ["nll", "crps"])
    @pytest.mark.parametrize("weights", [False, True])
    def test_censored_metric_fn_shapes_and_values(self, model, loss_fn, weights):
        model.dist.loss_fn = loss_fn
        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
        dmat = rest[-1]
        name, loss = model.dist.metric_fn(predt, dmat)
        assert name == loss_fn and isinstance(loss, torch.Tensor)
        assert not torch.isnan(loss).any() and not torch.isinf(loss).any()

    def test_metric_fn_exact_equals_uncensored(self, model):
        predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
        dmat = rest[-1]
        name_c, loss_c = model.dist.metric_fn(predt, dmat)
        underlying_cls = model.dist.__class__.__mro__[2]
Copilot AI commented on Aug 8, 2025:

Using hardcoded index [2] in the MRO (Method Resolution Order) is fragile and could break if the inheritance hierarchy changes. Consider using a more explicit approach like checking class names or using hasattr to find the base distribution class.

Suggested change:
-        underlying_cls = model.dist.__class__.__mro__[2]
+        # Find the first base class in the MRO that is not a censored distribution and not 'object'
+        underlying_cls = next(
+            cls for cls in model.dist.__class__.__mro__
+            if cls is not model.dist.__class__ and not cls.__name__.startswith("Censored") and cls is not object
+        )
        base_model = XGBoostLSS(underlying_cls())
        base_predt, base_labels, *base_rest = gen_test_data(base_model, weights=False, censored=False)
        base_dmat = base_rest[-1]
        name_b, loss_b = base_model.dist.metric_fn(base_predt, base_dmat)
        assert name_c == name_b and torch.allclose(loss_c, loss_b)
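For context on the review comment above: given the class definitions added in this PR (CensoredLogNormal(CensoredMixin, LogNormal), with CensoredMixin subclassing DistributionClass), the resolution order is approximately

    >>> from xgboostlss.distributions.LogNormal import CensoredLogNormal
    >>> CensoredLogNormal.__mro__
    (CensoredLogNormal, CensoredMixin, LogNormal, DistributionClass, ..., object)

so __mro__[2] happens to be the uncensored LogNormal today; that implicit coupling is exactly what the comment flags. (The tuple above is inferred from the class declarations, not captured from a run.) The new tests themselves can be run in isolation with pytest tests/test_distribution_utils/test_censored_utils.py.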
29 changes: 27 additions & 2 deletions tests/utils.py
@@ -10,7 +10,7 @@
import xgboost as xgb


-def gen_test_data(dist_class, weights: bool = False):
+def gen_test_data(dist_class, weights: bool = False, censored: bool = False):
"""
Function that generates test data for a given distribution class.

Expand All @@ -20,6 +20,8 @@ def gen_test_data(dist_class, weights: bool = False):
Distribution class.
weights (bool):
Whether to generate weights.
censored (bool):
Whether to generate censored data.

Returns:
--------
@@ -36,6 +38,28 @@
        np.random.seed(123)
        predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
        labels = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
        # Handle censored interval data
        if censored:
            # base values and censoring bounds
            labels = labels.flatten()
            lower = labels - 0.1
            upper = labels + 0.1
            if weights:
                weights_arr = np.ones_like(lower, dtype=lower.dtype)
                dmatrix = xgb.DMatrix(predt,
                                      label_lower_bound=lower,
                                      label_upper_bound=upper,
                                      weight=weights_arr)
                dist_class.set_base_margin(dmatrix)
                return predt, lower, upper, weights_arr, dmatrix
            else:
                dmatrix = xgb.DMatrix(predt,
                                      label_lower_bound=lower,
                                      label_upper_bound=upper)
                dist_class.set_base_margin(dmatrix)
                return predt, lower, upper, dmatrix
        # uncensored data

        if weights:
            weights = np.ones_like(labels)
            dmatrix = xgb.DMatrix(predt, label=labels, weight=weights)
@@ -48,6 +72,7 @@

        return predt, labels, dmatrix
    else:
        # multivariate (censored not supported)
        np.random.seed(123)
        predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
        labels = np.arange(0.1, 0.9, 0.1)
@@ -56,7 +81,7 @@
            dist_class.dist.n_targets,
            dist_class.dist.n_dist_param
        )
-        if weights:
+        if weights and not censored:
            weights = np.ones_like(labels[:, 0], dtype=labels.dtype).reshape(-1, 1)
            dmatrix = xgb.DMatrix(predt, label=labels, weight=weights)
            dist_class.set_base_margin(dmatrix)
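Note the return contract of the new censored branch (it is what the tests above rely on): without weights the helper returns (predt, lower, upper, dmatrix); with weights it returns (predt, lower, upper, weights_arr, dmatrix). This asymmetry is why the tests unpack with a starred target and take the last element, e.g.:

    predt, lower, upper, *rest = gen_test_data(model, weights=True, censored=True)
    dmat = rest[-1]        # the DMatrix is always the last element
    weights_arr = rest[0]  # present only when weights=True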
6 changes: 6 additions & 0 deletions xgboostlss/distributions/LogNormal.py
@@ -1,5 +1,6 @@
from torch.distributions import LogNormal as LogNormal_Torch
from .distribution_utils import DistributionClass
from .censored_utils import CensoredMixin
from ..utils import *


@@ -65,3 +66,8 @@ def __init__(self,
                         distribution_arg_names=list(param_dict.keys()),
                         loss_fn=loss_fn
                         )


class CensoredLogNormal(CensoredMixin, LogNormal):
"""LogNormal distribution with interval-censoring support."""
pass
7 changes: 7 additions & 0 deletions xgboostlss/distributions/Weibull.py
@@ -1,5 +1,6 @@
from torch.distributions import Weibull as Weibull_Torch
from .distribution_utils import DistributionClass
from .censored_utils import CensoredMixin
from ..utils import *


@@ -65,3 +66,9 @@ def __init__(self,
                         distribution_arg_names=list(param_dict.keys()),
                         loss_fn=loss_fn
                         )

class CensoredWeibull(CensoredMixin, Weibull):
"""
Weibull distribution class with interval censoring support.
"""
pass
74 changes: 74 additions & 0 deletions xgboostlss/distributions/censored_utils.py
@@ -0,0 +1,74 @@
import numpy as np
import torch
import xgboost as xgb
from typing import List, Tuple
from .distribution_utils import DistributionClass


class CensoredMixin(DistributionClass):
    """
    Mixin to add interval-censoring support to a distribution.
    Overrides objective_fn and metric_fn to dispatch to censored loss.
    """
    def objective_fn(self, predt: np.ndarray, data: xgb.DMatrix):
        lower = data.get_float_info("label_lower_bound")
        upper = data.get_float_info("label_upper_bound")
        if lower.size == 0 and upper.size == 0:
            return super().objective_fn(predt, data)
        if data.get_weight().size == 0:
            # initialize weights as ones with correct shape
            weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
Copilot AI commented on Aug 8, 2025:

Creating a tensor just to get its dtype and then converting back to numpy is inefficient. Consider using weights = np.ones((lower.shape[0], 1), dtype=lower.dtype) directly.

Suggested change:
-            weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
+            weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
        else:
            weights = data.get_weight().reshape(-1, 1)
        start_values = data.get_base_margin().reshape(-1, self.n_dist_param)[0, :].tolist()
        predt_list, loss = self.get_params_loss_censored(
            predt, start_values, lower, upper, requires_grad=True
        )
        grad, hess = self.compute_gradients_and_hessians(loss, predt_list, weights)
        return grad, hess

    def metric_fn(self, predt: np.ndarray, data: xgb.DMatrix):
        lower = data.get_float_info("label_lower_bound")
        upper = data.get_float_info("label_upper_bound")
        if lower.size == 0 and upper.size == 0:
            return super().metric_fn(predt, data)
        start_values = data.get_base_margin().reshape(-1, self.n_dist_param)[0, :].tolist()
        _, loss = self.get_params_loss_censored(
            predt, start_values, lower, upper, requires_grad=False
        )
        return self.loss_fn, loss

    def get_params_loss_censored(self,
                                 predt: np.ndarray,
                                 start_values: List[float],
                                 lower: np.ndarray,
                                 upper: np.ndarray,
                                 requires_grad: bool = False,
                                 ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Compute loss for interval-censored data."""
        predt_arr = predt.reshape(-1, self.n_dist_param)
        # replace nan/inf
        mask = np.isnan(predt_arr) | np.isinf(predt_arr)
        predt_arr[mask] = np.take(start_values, np.where(mask)[1])
        # convert to tensors
        predt_list = [
            torch.tensor(predt_arr[:, i].reshape(-1, 1), requires_grad=requires_grad)
            for i in range(self.n_dist_param)
        ]
        # transform parameters
        params_transformed = [
            fn(predt_list[i]) for i, fn in enumerate(self.param_dict.values())
        ]
        # instantiate distribution
        dist = self.distribution(**dict(zip(self.distribution_arg_names, params_transformed)))
        # compute cdf bounds: convert lower & upper once to tensor with correct dtype
        low = torch.as_tensor(lower, dtype=params_transformed[0].dtype).reshape(-1, 1)
        hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
Copilot AI commented on Aug 8, 2025:

Inconsistent spacing: 'hi =' has two spaces before the equals sign while 'low =' on the previous line has one space. This should be consistent.

Suggested change:
-        hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
+        hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
        cdf_low = dist.cdf(low)
        cdf_hi = dist.cdf(hi)
        # interval mass & loss
        mass = cdf_hi - cdf_low
        log_density = dist.log_prob(low)
        censored_inds = low != hi
        loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
Copilot AI commented on Aug 8, 2025:

The log density is computed using only the lower bound, but this should only be used for exact observations (non-censored data). For censored intervals where low != hi, this log_density value is incorrect and shouldn't contribute to the loss.

Suggested change:
-        loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
+        exact_inds = (low == hi)
+        log_density = dist.log_prob(low[exact_inds])
+        loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)
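For reference, the interval-censored negative log-likelihood both formulations aim at is

    NLL = - sum_{i: l_i < u_i} log( F(u_i) - F(l_i) )  -  sum_{i: l_i = u_i} log f(l_i)

where F and f are the CDF and density of the fitted distribution. Note the original line already restricts log_density to exact rows via ~censored_inds, so the two versions should agree in value; the suggestion mainly avoids evaluating log_prob on interval rows altogether.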
        return predt_list, loss
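Taken together, the new pieces would be exercised roughly as below (a minimal sketch, not part of the diff; the data is made up, and it assumes set_base_margin and model.dist behave as used in gen_test_data above):

    import numpy as np
    import xgboost as xgb
    from xgboostlss.model import XGBoostLSS
    from xgboostlss.distributions.LogNormal import CensoredLogNormal

    model = XGBoostLSS(CensoredLogNormal())

    # Toy interval-censored data: rows where lower == upper count as exact observations.
    n = 8
    X = np.random.rand(n, 3)
    y = np.linspace(0.2, 0.9, n)
    lower, upper = y - 0.05, y + 0.05
    lower[:2] = upper[:2] = y[:2]                      # first two rows are exact

    dmat = xgb.DMatrix(X, label_lower_bound=lower, label_upper_bound=upper)
    model.set_base_margin(dmat)                        # as done in gen_test_data

    # Raw scores, one column per distributional parameter, as in the tests.
    predt = np.random.rand(n, model.dist.n_dist_param)

    grad, hess = model.dist.objective_fn(predt, dmat)  # CensoredMixin dispatches to the censored loss
    name, loss = model.dist.metric_fn(predt, dmat)     # ("nll", tensor) by default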