feat: support for censored likelihoods #91
base: master
Changes from all commits: 56caa99, 3cad998, 967cfa3, 13e4202, c279bb5
New test file (@@ -0,0 +1,44 @@):

```python
from ..utils import BaseTestClass
import pytest
import numpy as np
import torch
from xgboostlss.model import XGBoostLSS
from xgboostlss.distributions.LogNormal import CensoredLogNormal
from xgboostlss.distributions.Weibull import CensoredWeibull
from tests.utils import gen_test_data


class TestCensoredUnivariate(BaseTestClass):
    @pytest.fixture(params=[CensoredLogNormal, CensoredWeibull])
    def model(self, request):
        return XGBoostLSS(request.param())

    @pytest.mark.parametrize("weights", [False, True])
    def test_censored_objective_fn_shapes_and_values(self, model, weights):
        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
        dmat = rest[-1]
        grad, hess = model.dist.objective_fn(predt, dmat)
        assert isinstance(grad, np.ndarray) and isinstance(hess, np.ndarray)
        assert grad.shape == predt.flatten().shape and hess.shape == predt.flatten().shape
        assert not np.isnan(grad).any() and not np.isnan(hess).any()
        assert not np.isinf(grad).any() and not np.isinf(hess).any()

    @pytest.mark.parametrize("loss_fn", ["nll", "crps"])
    @pytest.mark.parametrize("weights", [False, True])
    def test_censored_metric_fn_shapes_and_values(self, model, loss_fn, weights):
        model.dist.loss_fn = loss_fn
        predt, lower, upper, *rest = gen_test_data(model, weights=weights, censored=True)
        dmat = rest[-1]
        name, loss = model.dist.metric_fn(predt, dmat)
        assert name == loss_fn and isinstance(loss, torch.Tensor)
        assert not torch.isnan(loss).any() and not torch.isinf(loss).any()

    def test_metric_fn_exact_equals_uncensored(self, model):
        predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
        dmat = rest[-1]
        name_c, loss_c = model.dist.metric_fn(predt, dmat)
        underlying_cls = model.dist.__class__.__mro__[2]
        base_model = XGBoostLSS(underlying_cls())
        base_predt, base_labels, *base_rest = gen_test_data(base_model, weights=False, censored=False)
        base_dmat = base_rest[-1]
        name_b, loss_b = base_model.dist.metric_fn(base_predt, base_dmat)
        assert name_c == name_b and torch.allclose(loss_c, loss_b)
```
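For context on how censoring information reaches XGBoost here: the mixin below reads the bounds from the DMatrix float-info fields `label_lower_bound` and `label_upper_bound` (the same convention XGBoost's built-in AFT objective uses). `gen_test_data` itself is not shown in this PR, so the helper below is a hypothetical stand-in for how such data could be constructed:

```python
import numpy as np
import xgboost as xgb

def make_censored_dmatrix(n: int = 100, seed: int = 0) -> xgb.DMatrix:
    """Hypothetical helper: a DMatrix with interval-censoring bounds attached."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 5))
    y = rng.lognormal(mean=0.0, sigma=1.0, size=n)

    lower = y.copy()
    upper = y.copy()
    # Right-censor a random 30% of rows: the value is only known to exceed `lower`.
    censored = rng.random(n) < 0.3
    upper[censored] = np.inf

    dmat = xgb.DMatrix(X)
    # XGBoost's float-info fields for censoring bounds.
    dmat.set_float_info("label_lower_bound", lower)
    dmat.set_float_info("label_upper_bound", upper)
    return dmat
```

Exact (uncensored) observations are encoded by setting both bounds to the observed value.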
New file (@@ -0,0 +1,74 @@):
```python
import numpy as np
import torch
import xgboost as xgb
from typing import List, Tuple
from .distribution_utils import DistributionClass


class CensoredMixin(DistributionClass):
    """
    Mixin to add interval-censoring support to a distribution.
    Overrides objective_fn and metric_fn to dispatch to censored loss.
    """
    def objective_fn(self, predt: np.ndarray, data: xgb.DMatrix):
        lower = data.get_float_info("label_lower_bound")
        upper = data.get_float_info("label_upper_bound")
        if lower.size == 0 and upper.size == 0:
            return super().objective_fn(predt, data)
        if data.get_weight().size == 0:
            # initialize weights as ones with correct shape
            weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
```
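The diff view is truncated here; the rest of the 74-line file, including the censored loss computation that the review comments below discuss, is not shown. From the test file's imports and its `__mro__[2]` lookup, the concrete censored distributions are presumably thin compositions of this mixin with the existing distribution classes. A sketch of that inferred pattern (class names taken from the tests; the composition itself is not visible in this diff):

```python
from xgboostlss.distributions.LogNormal import LogNormal
from xgboostlss.distributions.Weibull import Weibull

# Hypothetical composition (inferred, not shown in the diff). Listing
# CensoredMixin first lets its objective_fn/metric_fn overrides win in the
# MRO, which becomes [CensoredLogNormal, CensoredMixin, LogNormal, ...] --
# hence model.dist.__class__.__mro__[2] in the tests recovering LogNormal.
class CensoredLogNormal(CensoredMixin, LogNormal):
    pass

class CensoredWeibull(CensoredMixin, Weibull):
    pass
```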
Copilot AI (Aug 8, 2025) left a suggested change on the weight initialization:

```diff
-            weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
+            weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
```
Copilot AI (Aug 8, 2025): Inconsistent spacing: `hi =` has two spaces before the equals sign while `low =` on the previous line has one. This should be consistent.
```diff
-        hi =  torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
+        hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
```
Copilot AI (Aug 8, 2025): The log density is computed using only the lower bound, but it should only be used for exact observations (non-censored data). For censored intervals where `low != hi`, this `log_density` value is incorrect and shouldn't contribute to the loss.
```diff
-        loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
+        exact_inds = (low == hi)
+        log_density = dist.log_prob(low[exact_inds])
+        loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)
```
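To make the suggested separation concrete, here is a minimal self-contained sketch of an interval-censored NLL along these lines, assuming one torch distribution per row, flat bound tensors, exact observations encoded as `low == hi`, and weights omitted for brevity; this is an illustration, not the PR's actual implementation:

```python
import torch

def censored_nll(dist: torch.distributions.Distribution,
                 low: torch.Tensor,
                 hi: torch.Tensor) -> torch.Tensor:
    # Exact rows contribute log-density; censored rows contribute the log of
    # the probability mass between their bounds.
    exact = low == hi

    log_density = dist.log_prob(low)  # only kept for exact rows below
    # For right-censored rows hi may be +inf, where cdf(+inf) == 1;
    # clamp_min guards log(0) on rows where the mass term is unused.
    mass = (dist.cdf(hi) - dist.cdf(low)).clamp_min(1e-12)

    per_row = torch.where(exact, -log_density, -torch.log(mass))
    return per_row.sum()
```

When every row is exact this reduces to the ordinary negative log-likelihood, which is the behavior `test_metric_fn_exact_equals_uncensored` above checks.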
Copilot AI (Aug 8, 2025): Using a hardcoded index `[2]` into the MRO (method resolution order) is fragile and could break if the inheritance hierarchy changes. Consider a more explicit approach, such as checking class names or using `hasattr`, to find the base distribution class.
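One way to implement that suggestion, sketched under the assumption that every concrete distribution inherits from `DistributionClass` (the import path for `CensoredMixin` is not shown in this diff):

```python
from xgboostlss.distributions.distribution_utils import DistributionClass
# CensoredMixin as defined in this PR's new module.

def underlying_distribution_class(dist) -> type:
    # Walk the MRO instead of hardcoding index [2]: return the first class
    # that is a concrete distribution, skipping the censored subclass itself,
    # the mixin, and the abstract base.
    for cls in type(dist).__mro__[1:]:
        if (issubclass(cls, DistributionClass)
                and cls is not DistributionClass
                and not issubclass(cls, CensoredMixin)):
            return cls
    raise TypeError(f"{type(dist).__name__} wraps no underlying distribution")
```

The test above could then replace `model.dist.__class__.__mro__[2]` with `underlying_distribution_class(model.dist)`.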