Hist training with checkpointing is non-deterministic based on subsample #10324

@andrew-esteban-imc

Hi there,

We have found that despite setting a seed for our hist training, we get non-deterministic results when resuming training from a checkpoint. A reproducer can be seen below:

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost.callback import TrainingCallback

# Generate random sample data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10, n_redundant=10, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


# Offset the iteration number passed to Booster.update so that a resumed run
# continues counting from start_epoch instead of restarting at 0
class XgbCheckpointCallback(TrainingCallback):
    def __init__(
            self,
            start_epoch: int
    ):
        super().__init__()
        self.start_epoch = start_epoch

    def before_training(self, model: xgb.Booster):
        self._prev_update = model.update

        def update(
                dtrain, iteration: int, fobj=None
        ) -> None:
            return self._prev_update(dtrain, iteration + self.start_epoch, fobj)

        model.update = update
        return model

    def after_training(self, model):
        model.update = self._prev_update
        return model


# Set the parameters for XGBoost training
params = {
    "verbosity": 0,  # replaces the deprecated "silent" flag
    "tree_method": "hist",
    "seed": 1,
    "base_score": -0.9,
    "max_depth": 3,
    "learning_rate": 0.02,
    "lambda": 500,
    "subsample": 0.9,
}

# Wrap the training and test splits in DMatrix objects
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# First run: train for 10 rounds without interruption
bst1 = xgb.train(params, dtrain, 10, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5, feval=None)

# Save the model after the first run
bst1.save_model('first_run_model')

# Second run: train for only 5 rounds (this acts as the checkpoint)
bst2 = xgb.train(params, dtrain, 5, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5, feval=None)

# Save the model after the second run
bst2.save_model('second_run_model')

# Third run: resume from the second run's booster and train for 5 more rounds
bst3 = xgb.train(params, dtrain, 5, evals=[(dtrain, 'train'), (dtest, 'test')], early_stopping_rounds=5, feval=None,
                 xgb_model=bst2, callbacks=[XgbCheckpointCallback(5)])

# Save the model after the third run
bst3.save_model('third_run_model')

# Make predictions on the test set for each run
preds1 = bst1.predict(dtest)
preds2 = bst2.predict(dtest)
preds3 = bst3.predict(dtest)

# Compare the mean test-set prediction of the uninterrupted and resumed runs
print("Mean test prediction for run 1: {:.5f}".format(np.mean(preds1)))
print("Mean test prediction for run 3: {:.5f}".format(np.mean(preds3)))

We use XgbCheckpointCallback to work around a related issue: when training resumes from a checkpoint, the iteration counter restarts at 0 instead of continuing from the epoch the checkpoint reached. The callback can be removed, but then setting any of the colsample_* params to a value below 1.0 will produce the same issue.
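
To illustrate why the iteration offset might matter, here is a toy sketch. It assumes the column sampler is keyed on (seed, iteration), which is our reading of the behaviour and not a description of XGBoost's actual implementation. The names column_mask, masks and same are hypothetical helpers made up for this example.

```python
import numpy as np


def column_mask(seed, iteration, n_cols=20, frac=0.5):
    # Toy sampler (assumption): the RNG depends only on (seed, iteration).
    rng = np.random.default_rng([seed, iteration])
    return rng.random(n_cols) < frac


def masks(seed, n_rounds, start=0):
    # Masks for rounds start, start + 1, ..., start + n_rounds - 1.
    return [column_mask(seed, start + i) for i in range(n_rounds)]


def same(a, b):
    # True if both runs used identical masks in every round.
    return len(a) == len(b) and all(np.array_equal(x, y) for x, y in zip(a, b))


uninterrupted = masks(seed=1, n_rounds=10)
# Resume that restarts the iteration counter at 0 (no callback).
resumed_no_offset = masks(1, 5) + masks(1, 5, start=0)
# Resume that continues counting from round 5 (what the callback does).
resumed_with_offset = masks(1, 5) + masks(1, 5, start=5)

print(same(uninterrupted, resumed_with_offset))  # True
print(same(uninterrupted, resumed_no_offset))    # differs unless masks coincide
```

In this toy model the resumed run only matches the uninterrupted one when the second half is sampled with iterations 5 to 9, which is the effect the callback is meant to achieve.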

With tree_method set to exact, the uninterrupted model and the interrupted model are identical. With tree_method set to hist and subsample set to 1.0, they are also identical. With hist and subsample < 1.0, however, the results differ.
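
A stricter check than comparing mean predictions is to compare the text dumps of the trees. The sketch below is a minimal helper for that; first_differing_tree is a hypothetical name, and it only assumes that each model is given as a list of strings with one entry per tree, which is what Booster.get_dump() returns.

```python
def first_differing_tree(dump_a, dump_b):
    """Return the index of the first tree whose dump differs, or None if equal."""
    for i, (tree_a, tree_b) in enumerate(zip(dump_a, dump_b)):
        if tree_a != tree_b:
            return i
    if len(dump_a) != len(dump_b):
        # One model has extra trees after the common prefix.
        return min(len(dump_a), len(dump_b))
    return None


# Hand-written stand-ins for real dumps, for illustration only.
run_a = ["0:leaf=0.1", "0:leaf=0.2", "0:leaf=0.3"]
run_b = ["0:leaf=0.1", "0:leaf=0.2", "0:leaf=0.4"]
print(first_differing_tree(run_a, run_b))  # 2
```

With the reproducer above, calling first_differing_tree(bst1.get_dump(), bst3.get_dump()) should show at which boosting round the uninterrupted and resumed models start to diverge.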

I've seen #6711, but that seems to be somewhat different in nature.
