Max margin interval trees
There are few R packages available for interval regression, a machine learning problem which is important in genomics and medicine. Like usual regression, the goal is to learn a function that inputs a feature vector and outputs a real-valued prediction. Unlike usual regression, each output in the training set is an interval of acceptable values (rather than one value). In the terminology of the survival analysis literature, this is regression with “left, right, and interval censored” output/response data.
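To make the output encoding concrete, here is a minimal R sketch of one common representation of left-, right-, and interval-censored targets, using survival::Surv; the numeric values are invented purely for illustration.

```r
## Each training output is an interval (lower, upper) of acceptable values.
## -Inf encodes left censoring and Inf encodes right censoring;
## survival::Surv with type = "interval2" represents open ends as NA.
library(survival)
lower <- c(-Inf, 2.5, 4.0)  # left-censored, interval-censored, right-censored
upper <- c( 3.0, 3.7, Inf)
Surv(time  = ifelse(is.finite(lower), lower, NA),
     time2 = ifelse(is.finite(upper), upper, NA),
     type  = "interval2")
```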
Max margin interval trees (MMIT) is a new nonlinear model for this problem (Drouin et al., 2017). A dynamic programming algorithm is used to find the optimal split point for each feature. This algorithm has been implemented in C++, and there are wrappers to this solver in R and Python (https://github.com/aldro61/mmit). The Python package includes a decision tree learner; however, the decision tree learner has not yet been implemented in the R package. The goal of this project is to write an R package that implements the decision tree learner in R, using partykit.
The transformation forest model of Hothorn and Zeileis implements a decision tree model that can be trained on censored outputs (https://arxiv.org/abs/1701.02110); the trtf package on CRAN implements this nonlinear model. Similarly, the LTRCtree package (Survival Trees to Fit Left-Truncated and Right-Censored and Interval-Censored Survival Data) handles these types of responses. It builds on infrastructure from partykit and can be studied as an example of how to use and extend partykit.
There are several linear models which can be trained on censored outputs.
- iregnet implements elastic net regularized Accelerated Failure Time models.
- penaltyLearning::IntervalRegressionCV implements a solver for the squared hinge loss with L1 regularization (a minimal usage example follows this list).
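For concreteness, a minimal sketch of the second baseline, following the documented usage of penaltyLearning::IntervalRegressionCV on the data set that is also used in the tests at the bottom of this page:

```r
## Linear interval regression baseline: L1-regularized squared hinge loss,
## with the regularization parameter chosen by cross-validation.
library(penaltyLearning)
data(neuroblastomaProcessed, package = "penaltyLearning")
fit <- IntervalRegressionCV(
  feature.mat = neuroblastomaProcessed$feature.mat,  # inputs: one row per example
  target.mat  = neuroblastomaProcessed$target.mat)   # outputs: interval limits
pred <- predict(fit, neuroblastomaProcessed$feature.mat)
head(pred)  # real-valued predictions, ideally inside the target intervals
```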
The Python module is organized as follows:
- mmit.core: This submodule implements an interface (mmit.core.compute_optimal_costs) to the dynamic programming algorithm (solver) used to find the optimal split point for each feature. The C++ code for the solver is located in this directory.
- mmit.learning: This submodule implements the MaxMarginIntervalTree class, which is used to learn decision trees and compute predictions. The class is compatible with the scikit-learn API: a tree is learned with the fit method and predictions are computed with the predict method (a short usage sketch follows this list). The tree learner makes multiple calls to the C++ solver to find the best rule for splitting a node at each step of the learning process.
- mmit.metrics: This submodule implements metrics used to measure the accuracy of predictions with respect to the target intervals. The supported metrics are the zero-one loss and the mean squared error.
- mmit.model: This submodule implements the inner workings of the tree models learned by the MaxMarginIntervalTree class. The RegressionTreeNode class implements a tree node and is used recursively to construct trees. The TreeExporter class exports tree models in various formats; the only format currently supported is TikZ/LaTeX.
- mmit.model_selection: The GridSearchCV class trains an MMIT using cross-validation to select the hyper-parameters (see below). Optionally, minimum cost-complexity pruning can be used, in which case cross-validation also selects the optimal size for the tree.
- mmit.pruning: This submodule implements minimum cost-complexity pruning (Breiman et al. 1984). This type of pruning is a regularization method that helps avoid overfitting when learning decision trees.
- mmit.tests: This submodule implements unit tests.
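For readers more familiar with R, here is a hedged sketch of the training and prediction workflow described above, calling the Python module via reticulate. The class and method names follow the description above, but the encoding of y as an n x 2 matrix of interval limits (with -Inf/Inf for open ends) is an assumption, not verbatim package documentation.

```r
## Hypothetical usage sketch of the Python API via reticulate; assumes the
## mmit Python module is installed and that y is an n x 2 matrix of
## interval limits (-Inf/Inf for open ends).
library(reticulate)
mmit.learning <- import("mmit.learning")
X <- matrix(rnorm(40), nrow = 10)             # 10 examples, 4 features
y <- cbind(runif(10, 0, 1), runif(10, 2, 3))  # lower and upper limits
tree <- mmit.learning$MaxMarginIntervalTree(
  margin = 0.5, loss = "squared_hinge")
tree$fit(X, y)           # scikit-learn-style training
pred <- tree$predict(X)  # one real-valued prediction per example
```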
The following functionality is not currently implemented in the Python module, but contributions are welcome:
- Learning random forests of MMITs
- Using MMITs as voters in a boosting algorithm, such as Adaboost
The tree learner (MaxMarginIntervalTree) has several hyperparameters, which we explain below. For more details, please refer to the MMIT paper (Drouin et al., 2017).
- margin: The margin parameter of the loss function. This is a real number greater than or equal to zero.
- loss: The type of loss function used by the algorithm. Possible choices are “linear_hinge” for the linear hinge loss and “squared_hinge” for the squared hinge loss.
- max_depth: The maximum depth of the tree. This is a positive integer.
- min_samples_split: The minimum number of examples that must be in a node for it to be considered for splitting by the recursive partitioning algorithm. If a node contains too few examples, it is not split and becomes a leaf in the final model. This is an integer greater than or equal to two.
When pruning is used (i.e., GridSearchCV([…], pruning=True)), the max_depth and min_samples_split parameters can be set to very large values. The pruning will take care of selecting the right size for the tree using the cross-validation procedure described by Breiman et al. (1984).
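Continuing the reticulate sketch above, one way this hyper-parameter search with pruning could look; only pruning=True is taken from the text above, and the assumption that GridSearchCV mirrors scikit-learn's signature (an estimator plus a parameter grid) is ours, not confirmed by this page.

```r
## Hypothetical hyper-parameter selection sketch; the GridSearchCV
## signature is assumed to mirror scikit-learn's.
mmit.ms <- import("mmit.model_selection")
cv <- mmit.ms$GridSearchCV(
  mmit.learning$MaxMarginIntervalTree(),
  param_grid = list(
    margin    = c(0.0, 0.5, 1.0),
    loss      = c("linear_hinge", "squared_hinge"),
    max_depth = list(1000L)),  # large: pruning selects the final tree size
  pruning = TRUE)
cv$fit(X, y)  # reuses the X and y from the sketch above
```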
Implement the Max Margin Interval Tree model in the partykit framework, as an mmit R package with the following (a hypothetical usage sketch follows this list):
- mmit() to train a tree model (for a given set of hyper-parameters).
- cv.mmit() to train a tree model using K-fold cross-validation to select hyper-parameters and optionally, to perform minimum cost-complexity pruning, which uses cross-validation to select the best size for the tree.
- mmif() to train random forests of MMITs.
- mmitboost() to learn an Adaboost regressor using MMIT estimators.
- documentation and tests for each function.
- development on GitHub with code quality assurance (code coverage and Travis CI for testing).
- vignette to explain typical package usage.
- Possibly a new tree visualization, with plots in the leaves specifically designed for interval-censored outputs.
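To make these deliverables concrete, here is a purely hypothetical sketch of what the resulting R interface could look like; none of these signatures exist yet, and all argument names are illustrative only.

```r
## Hypothetical interface for the proposed mmit R package; nothing below
## exists yet, and every function signature is illustrative only.
library(mmit)  # the package this project would create
data(neuroblastomaProcessed, package = "penaltyLearning")
X <- neuroblastomaProcessed$feature.mat
y <- neuroblastomaProcessed$target.mat  # interval limits, one row per example
fit <- mmit(X, y, margin = 1, loss = "squared_hinge",
            max_depth = 4, min_samples_split = 2)
pred <- predict(fit, X)
cv.fit <- cv.mmit(X, y, n.folds = 5, pruning = TRUE)  # CV + optional pruning
forest <- mmif(X, y, n.trees = 100)                   # random forest of MMITs
boost  <- mmitboost(X, y, n.trees = 50)               # Adaboost with MMIT voters
```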
Breiman, L., Friedman, J., Stone, C. J., & Olshen, R. A. (1984). Classification and regression trees. CRC press.
Drouin, A., Hocking, T. D., & Laviolette, F. (2017). Maximum Margin Interval Trees. Proceedings of the 31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.
This project will provide an R implementation of the max margin interval tree model for interval regression, which currently only has a Python implementation.
Students, please contact the mentors listed below after completing at least one of the tests.
- Alexandre Drouin <[email protected]> is a co-author of the Max Margin Interval Trees paper, and author of the Python mmit module and C++ code.
- Torsten Hothorn <[email protected]> is an expert in implementing decision tree algorithms in R; he is the author of the trtf and partykit packages.
- Backup mentor: Toby Hocking <[email protected]> is a co-author of the Max Margin Interval Trees paper, the author of the R package penaltyLearning (which implements a linear interval regression algorithm), and the mentor of the students who implemented the iregnet package (GSoC 2016-2017).
Students, please complete one or more of the following tests before contacting the mentors above. Ideally, write up your commented code in an Rmd document and post it to a GitHub repository for us to read and review.
- Easy: run some R code that shows you know how to train and test a decision tree model (rpart, partykit, etc.); a minimal sketch of the cross-validation setup follows this list. Bonus points if you can get trtf running for an interval regression problem, for example data(neuroblastomaProcessed, package="penaltyLearning"). Use 5-fold cross-validation to compare the learned decision tree models to a trivial baseline that ignores the features and always predicts the most likely value based on the train labels.
- Medium: Read the partykit vignette to learn how to implement a new tree model using the partykit framework. Use it to re-implement a simple version of Breiman’s CART algorithm (rpart R package). Demonstrate the equivalence of your code and rpart on the data set in example(rpart).
- Hard: Read the help page of the survival::survreg function, which can be used to fit a linear model for censored outputs. Use it as a sub-routine to implement a (slow) regression tree for interval censored output data. Search for the best possible split over all features – the best split is the one that maximizes the logLik of the survreg model. Demonstrate that your regression tree model works on a small subset of data(neuroblastomaProcessed, package="penaltyLearning").
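As a starting point for the easy test, here is a minimal sketch of the 5-fold cross-validation comparison on an ordinary regression data set; mtcars is used only for illustration, and the interval-censored case is the bonus part.

```r
## 5-fold CV comparing an rpart regression tree to a featureless baseline
## that ignores the features and always predicts the training-set mean.
library(rpart)
set.seed(1)
fold <- sample(rep(1:5, length.out = nrow(mtcars)))
err <- sapply(1:5, function(k) {
  train <- mtcars[fold != k, ]
  test  <- mtcars[fold == k, ]
  tree  <- rpart(mpg ~ ., data = train)
  c(tree     = mean((predict(tree, test) - test$mpg)^2),
    baseline = mean((mean(train$mpg) - test$mpg)^2))
})
rowMeans(err)  # mean squared test error for each model
```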
Students, please post a link to your test results here.
Parismita Das code | documentation