11 changes: 10 additions & 1 deletion xgboostlss/model.py
@@ -319,6 +319,7 @@ def hyper_opt(
dtrain: DMatrix,
num_boost_round=500,
nfold=10,
folds=None,
early_stopping_rounds=20,
max_minutes=10,
n_trials=None,
@@ -340,6 +341,13 @@
Number of boosting iterations.
nfold: int
Number of folds in CV.
folds: a KFold or StratifiedKFold instance, or a list of fold indices
    Sklearn KFold or StratifiedKFold object.
    Alternatively, sample indices may be passed explicitly for each fold.
    For ``n`` folds, **folds** should be a length-``n`` list of tuples.
    Each tuple is ``(in, out)``, where ``in`` is a list of indices used as
    the training samples for the ``n``-th fold and ``out`` is a list of
    indices used as the testing samples for the ``n``-th fold.
early_stopping_rounds: int
Activates early stopping. Cross-Validation metric (average of validation
metric computed over CV folds) needs to improve at least once in
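The ``folds`` contract described in the docstring above (either a scikit-learn splitter object, or an explicit length-``n`` list of ``(in, out)`` index tuples) can be sketched as follows. The data and variable names here are illustrative only and are not part of the PR:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Illustrative data: 100 rows, with a binary label used only to stratify.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
y = np.repeat([0, 1], 50)

# Option 1: a scikit-learn splitter object can be passed as ``folds`` directly.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Option 2: materialise explicit (in, out) index tuples, one per fold.
folds = [
    (train_idx.tolist(), test_idx.tolist())
    for train_idx, test_idx in skf.split(X, y)
]

# Each fold partitions all 100 row indices into disjoint train/test sets.
for in_idx, out_idx in folds:
    assert len(in_idx) + len(out_idx) == 100
```

Either form could then be forwarded via ``hyper_opt(..., folds=folds)``, which the patched code passes straight through to the underlying ``xgb.cv`` call, overriding ``nfold``.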
@@ -417,7 +425,8 @@ def objective(trial):
early_stopping_rounds=early_stopping_rounds,
callbacks=[pruning_callback],
seed=seed,
verbose_eval=False
verbose_eval=False,
folds=folds,
)

# Add the optimal number of rounds