Conversation

John-Curcio commented Jul 2, 2025

  • add CensoredMixin class to extend fitting univariate distributions to censored data
  • add CensoredLogNormal and CensoredWeibull
  • test cases for censored data

xgb.DMatrix already has label_lower_bound, label_upper_bound for the user to specify right-, left-, or interval-censored data. This PR adds a CensoredMixin class which simply overrides objective_fn, metric_fn to accommodate censored data. So to fit a LogNormal distribution to such a dataset, just use CensoredLogNormal instead of LogNormal.

I've added CensoredLogNormal and CensoredWeibull.

I'm happy to further update the docs and add examples.

@StatMixedML
Owner

@John-Curcio Thanks for opening the PR and your effort!

Would need some time to review it.

StatMixedML requested a review from Copilot, August 8, 2025 08:15

Copilot AI left a comment


Pull Request Overview

This PR adds support for fitting censored data by introducing a CensoredMixin class that extends univariate distributions to handle interval-censored observations. The mixin overrides objective_fn and metric_fn to compute likelihood functions using cumulative distribution functions (CDFs) for censored intervals.

  • Adds CensoredMixin class with censored likelihood computation
  • Implements CensoredLogNormal and CensoredWeibull distribution classes
  • Adds comprehensive test coverage for censored data functionality

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Summary per file:

  • xgboostlss/distributions/censored_utils.py: Core CensoredMixin class implementing censored likelihood functions
  • xgboostlss/distributions/Weibull.py: Adds CensoredWeibull class inheriting from CensoredMixin and Weibull
  • xgboostlss/distributions/LogNormal.py: Adds CensoredLogNormal class inheriting from CensoredMixin and LogNormal
  • tests/utils.py: Extends test data generation to support censored data scenarios
  • tests/test_distribution_utils/test_censored_utils.py: Test suite validating censored distribution functionality

mass = cdf_hi - cdf_low
log_density = dist.log_prob(low)
censored_inds = low != hi
loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])

Copilot AI Aug 8, 2025


The log density is computed using only the lower bound, but this should only be used for exact observations (non-censored data). For censored intervals where low != hi, this log_density value is incorrect and shouldn't contribute to the loss.

Suggested change
loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
exact_inds = (low == hi)
log_density = dist.log_prob(low[exact_inds])
loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)

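Copilot's suggestion can be sketched as a self-contained PyTorch function. This is an illustrative reconstruction, not the PR's actual implementation: exact observations (low == hi) contribute the log density, while censored intervals contribute the log of the CDF mass.

```python
# Sketch of the corrected interval-censored negative log-likelihood.
# Assumes a torch.distributions object with scalar parameters so it
# broadcasts over indexed subsets; names mirror the review comment.
import torch
from torch.distributions import LogNormal

def censored_nll(dist, low, hi):
    """NLL where low == hi marks an exact observation, low != hi an interval."""
    exact_inds = (low == hi)
    # Exact observations: contribute log f(y)
    log_density = dist.log_prob(low[exact_inds])
    # Censored intervals: contribute log(F(hi) - F(low))
    mass = dist.cdf(hi[~exact_inds]) - dist.cdf(low[~exact_inds])
    return -torch.sum(torch.log(mass)) - torch.sum(log_density)

dist = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
low = torch.tensor([1.0, 0.5, 2.0, 1.0])
hi = torch.tensor([1.0, 1.5, float("inf"), 1.0])  # inf => right-censored
loss = censored_nll(dist, low, hi)
```

Note that an infinite upper bound gives F(hi) = 1, so right-censoring falls out of the interval case for free.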
return super().objective_fn(predt, data)
if data.get_weight().size == 0:
# initialize weights as ones with correct shape
weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()

Copilot AI Aug 8, 2025


Creating a tensor just to get its dtype and then converting back to numpy is inefficient. Consider using weights = np.ones((lower.shape[0], 1), dtype=lower.dtype) directly.

Suggested change
weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)

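The suggested simplification can be sketched in isolation (the `lower` array here is a made-up stand-in for the PR's variable):

```python
# Sketch of the reviewer's suggestion: build the default weight column
# directly in NumPy, instead of creating a torch tensor only to read its
# dtype and convert back.
import numpy as np

lower = np.array([0.5, 1.0, 2.0], dtype=np.float32)  # illustrative bounds
weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
```

This keeps the dtype aligned with the label bounds without a torch round-trip.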
predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
dmat = rest[-1]
name_c, loss_c = model.dist.metric_fn(predt, dmat)
underlying_cls = model.dist.__class__.__mro__[2]

Copilot AI Aug 8, 2025


Using hardcoded index [2] in the MRO (Method Resolution Order) is fragile and could break if the inheritance hierarchy changes. Consider using a more explicit approach like checking class names or using hasattr to find the base distribution class.

Suggested change
underlying_cls = model.dist.__class__.__mro__[2]
# Find the first base class in the MRO that is not a censored distribution and not 'object'
underlying_cls = next(
cls for cls in model.dist.__class__.__mro__
if cls is not model.dist.__class__ and not cls.__name__.startswith("Censored") and cls is not object
)

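To see why the name-based MRO scan is more robust than a hardcoded index, here is a runnable sketch with stand-in classes that mirror the PR's inheritance shape (the real classes live in xgboostlss):

```python
# Stand-in classes mimicking CensoredWeibull(CensoredMixin, Weibull).
# MRO is [CensoredWeibull, CensoredMixin, Weibull, object]; the scan
# below finds Weibull without relying on a fixed position.
class Weibull:
    pass

class CensoredMixin:
    pass

class CensoredWeibull(CensoredMixin, Weibull):
    pass

underlying_cls = next(
    cls for cls in CensoredWeibull.__mro__
    if cls is not CensoredWeibull
    and not cls.__name__.startswith("Censored")
    and cls is not object
)
```

If another mixin were inserted into the hierarchy, `__mro__[2]` would silently point at the wrong class, while the filtered scan still lands on the underlying distribution.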
dist = self.distribution(**dict(zip(self.distribution_arg_names, params_transformed)))
# compute cdf bounds: convert lower & upper once to tensor with correct dtype
low = torch.as_tensor(lower, dtype=params_transformed[0].dtype).reshape(-1, 1)
hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)

Copilot AI Aug 8, 2025


Inconsistent spacing: 'hi =' has two spaces before the equals sign while 'low =' on the previous line has one space. This should be consistent.

Suggested change
hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)

StatMixedML mentioned this pull request Aug 8, 2025
@StatMixedML
Owner

@claude Review this PR


claude bot commented Aug 8, 2025

Claude encountered an error: failed with exit code 128.

I'll analyze this and get back to you.

@StatMixedML
Owner

@John-Curcio Can you please look into these?
