scikit-survival 0.11 featuring Random Survival Forests released
Today, I released a new version of scikit-survival which includes an implementation of Random Survival Forests. Like its popular counterparts for classification and regression, a Random Survival Forest is an ensemble of tree-based learners. A Random Survival Forest ensures that individual trees are de-correlated by 1) building each tree on a different bootstrap sample of the original training data, and 2) at each node, evaluating the split criterion only for a randomly selected subset of features and thresholds. Predictions are formed by aggregating the predictions of the individual trees in the ensemble.
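The two sources of randomness can be made concrete with a minimal plain-Python sketch. The helpers bootstrap_indices and candidate_features are hypothetical. The "sqrt" rule assumes scikit-learn's convention of max(1, int(sqrt(n_features))) candidate features per node. The sketch only illustrates the sampling; it is not scikit-survival's tree-building code.

```python
import math
import random


def bootstrap_indices(n_samples, rng):
    """Hypothetical helper: draw n_samples row indices with replacement."""
    return [rng.randrange(n_samples) for _ in range(n_samples)]


def candidate_features(n_features, rng):
    """Hypothetical helper: the random subset of features one node may split on.

    Assumes the "sqrt" rule: max(1, int(sqrt(n_features))) features per node.
    """
    k = max(1, int(math.sqrt(n_features)))
    return sorted(rng.sample(range(n_features), k))


rng = random.Random(0)
n_samples, n_features = 686, 8  # size of the GBSG-2 data used below
for tree in range(3):
    rows = bootstrap_indices(n_samples, rng)
    features = candidate_features(n_features, rng)
    # each tree sees a different set of distinct patients (about 63% on average)
    print(f"tree {tree}: {len(set(rows))} distinct rows, root may split on {features}")
```

Because every tree starts from different rows and every node from different candidate features, the trees make partly independent errors, and averaging them reduces variance.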
For a full list of changes in scikit-survival 0.11, please see the release notes.
The latest version can be downloaded via conda or pip. Pre-built conda packages are available for Linux, OSX and Windows via
conda install -c sebp scikit-survival
Alternatively, scikit-survival can be installed from source via pip:
pip install -U scikit-survival
Using Random Survival Forests
To demonstrate Random Survival Forests, I’m going to use data from the German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women and 8 prognostic factors:
- age
- estrogen receptor (estrec)
- whether or not a hormonal therapy was administered (horTh)
- menopausal status (menostat)
- number of positive lymph nodes (pnodes)
- progesterone receptor (progrec)
- tumor size (tsize)
- tumor grade (tgrade)
The goal is to predict recurrence-free survival time.
The code to reproduce the results below is available in this notebook.
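First, we need to load the data and transform it into numeric values. Tumor grade is ordinal and is encoded as 0, 1, 2. The remaining categorical features are one-hot encoded.
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder

X, y = load_gbsg2()
# ordinal encoding of tumor grade: I -> 0, II -> 1, III -> 2
grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)
# one-hot encode the other categorical columns; sksurv's encoder returns a DataFrame
X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt = np.column_stack((Xt.values, grade_num))
feature_names = X_no_grade.columns.tolist() + ["tgrade"]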
Next, the data is split into 75% for training and 25% for testing so we can determine how well our model generalizes.
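from sklearn.model_selection import train_test_split

random_state = 20  # fixed seed so that the split and the forest are reproducible
X_train, X_test, y_train, y_test = train_test_split(
    Xt, y, test_size=0.25, random_state=random_state)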
Several split criteria have been proposed in the past, but the most widespread one is based on the log-rank test, which you probably know from comparing survival curves among two or more groups. Using the training data, we fit a Random Survival Forest comprising 1000 trees.
from sksurv.ensemble import RandomSurvivalForest

rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           max_features="sqrt",
                           n_jobs=-1,  # build trees in parallel on all cores
                           random_state=random_state)
rsf.fit(X_train, y_train)
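The log-rank split criterion can be sketched in plain Python. For a candidate threshold on one feature, samples go to the left or right child. The absolute standardized log-rank statistic then measures how much the survival of the two children differs, and the node keeps the threshold with the largest value.

This assumes the standard two-sample log-rank statistic in its standardized, absolute-value form. The function names are hypothetical, and this is an illustration, not scikit-survival's implementation.

```python
import math


def logrank_statistic(time, event, goes_left):
    """Absolute standardized two-sample log-rank statistic for one split."""
    event_times = sorted({t for t, e in zip(time, event) if e})
    observed_minus_expected = 0.0
    variance = 0.0
    for t in event_times:
        at_risk = [i for i, ti in enumerate(time) if ti >= t]
        n = len(at_risk)
        n_left = sum(1 for i in at_risk if goes_left[i])
        deaths = [i for i in at_risk if time[i] == t and event[i]]
        d = len(deaths)
        d_left = sum(1 for i in deaths if goes_left[i])
        # observed minus expected number of events in the left child at time t
        observed_minus_expected += d_left - d * n_left / n
        if n > 1:
            # hypergeometric variance of d_left
            variance += d * (n_left / n) * (1 - n_left / n) * (n - d) / (n - 1)
    if variance == 0.0:
        return 0.0
    return abs(observed_minus_expected) / math.sqrt(variance)


def best_threshold(x, time, event):
    """Try every threshold of one feature and keep the largest log-rank statistic."""
    best_thr, best_stat = None, -1.0
    for thr in sorted(set(x))[:-1]:  # the largest value would leave the right child empty
        stat = logrank_statistic(time, event, [xi <= thr for xi in x])
        if stat > best_stat:
            best_thr, best_stat = thr, stat
    return best_thr, best_stat


x = [1, 2, 3, 4, 5, 6]
time = [2, 3, 4, 10, 12, 15]
event = [1, 1, 1, 1, 0, 1]
print(best_threshold(x, time, event))
```

In a forest, this search runs only over the randomly selected candidate features of each node.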
We can check how well the model performs by evaluating it on the test data. The score method returns Harrell’s concordance index.
rsf.score(X_test, y_test)
This gives a concordance index of 0.68, which is a good value and matches the results reported in the Random Survival Forests paper.
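Harrell’s concordance index is the fraction of comparable patient pairs whose predicted risks are ordered correctly. A pair is comparable if the patient with the shorter time actually experienced the event. It is concordant if that patient also received the higher risk score.

The plain-Python sketch below is a simplified version: tied risk scores count as half-concordant, and, as an assumption for brevity, pairs with tied times are ignored. scikit-survival ships a complete implementation as sksurv.metrics.concordance_index_censored.

```python
def concordance_index(time, event, risk):
    """Simplified Harrell's C: ties in time are skipped, ties in risk count 0.5."""
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored patient cannot be the earlier member of a pair
        for j in range(len(time)):
            if time[j] > time[i]:  # j outlived i, so i should have the higher risk
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable


time = [5, 10, 15, 20]
event = [1, 1, 0, 1]
print(concordance_index(time, event, [4, 3, 2, 1]))  # 1.0: risk perfectly ordered
print(concordance_index(time, event, [4, 1, 2, 3]))  # 0.6: 3 of 5 comparable pairs concordant
```

A value of 0.5 corresponds to random predictions and 1.0 to a perfect ranking, which puts the 0.68 above into perspective.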
For prediction, a sample is dropped down each tree in the forest until it reaches a terminal node. The training data in each terminal node is used to non-parametrically estimate the survival and cumulative hazard function using the Kaplan-Meier and Nelson-Aalen estimators, respectively. In addition, a risk score can be computed that represents the expected number of events for one particular terminal node. The ensemble prediction is simply the average across all trees in the forest.
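The prediction step can be sketched in plain Python. Each tree contributes the training samples (time, event) of the terminal node the new sample reaches. From them, the sketch computes the Kaplan-Meier survival and Nelson-Aalen cumulative hazard on a common time grid and averages across trees.

The helper names are hypothetical. Summing the ensemble cumulative hazard over the grid as the risk score is an assumption that mirrors the "expected number of events" interpretation above. This is not scikit-survival's code.

```python
def leaf_estimates(time, event, grid):
    """Kaplan-Meier S(t) and Nelson-Aalen H(t) of one terminal node, evaluated on grid."""
    steps = []  # (event time, S just after it, H just after it)
    surv, chf = 1.0, 0.0
    for t in sorted({t for t, e in zip(time, event) if e}):
        n_at_risk = sum(1 for ti in time if ti >= t)
        deaths = sum(1 for ti, e in zip(time, event) if ti == t and e)
        surv *= 1.0 - deaths / n_at_risk  # Kaplan-Meier product
        chf += deaths / n_at_risk  # Nelson-Aalen sum
        steps.append((t, surv, chf))
    values = []
    for g in grid:
        s, h = 1.0, 0.0  # before the first event: S = 1, H = 0
        for t, s_t, h_t in steps:
            if t <= g:  # right-continuous step function
                s, h = s_t, h_t
        values.append((s, h))
    return values


def ensemble_predict(leaves, grid):
    """Average per-tree leaf estimates; leaves holds one (time, event) pair per tree."""
    per_tree = [leaf_estimates(time, event, grid) for time, event in leaves]
    n_trees = len(per_tree)
    surv = [sum(tree[k][0] for tree in per_tree) / n_trees for k in range(len(grid))]
    chf = [sum(tree[k][1] for tree in per_tree) / n_trees for k in range(len(grid))]
    risk = sum(chf)  # assumption: risk = ensemble cumulative hazard summed over the grid
    return surv, chf, risk


# the sample lands in a different terminal node in each of two trees
leaves = [([2, 4, 6], [1, 0, 1]), ([3, 5], [1, 1])]
grid = [2, 3, 4, 5, 6]  # e.g. the distinct event times seen during training
surv, chf, risk = ensemble_predict(leaves, grid)
print(surv)
print(chf)
print(risk)
```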
Let’s first select a couple of patients from the test data: the three with the fewest and the three with the most positive lymph nodes, using age to break ties.
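import pandas as pd

# structured array with one record per test patient (column 0 is age, column 4 is pnodes)
a = np.empty(X_test.shape[0], dtype=[("age", float), ("pnodes", float)])
a["age"] = X_test[:, 0]
a["pnodes"] = X_test[:, 4]
# sort by number of positive lymph nodes, then by age
sort_idx = np.argsort(a, order=["pnodes", "age"])
X_test_sel = pd.DataFrame(
    X_test[np.concatenate((sort_idx[:3], sort_idx[-3:]))],
    columns=feature_names)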
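np.argsort with order=["pnodes", "age"] sorts lexicographically: first by the number of positive lymph nodes, and by age only among patients with the same count. The hypothetical helper below reproduces that selection in plain Python on made-up values. Exact ties on both fields may be ordered differently from NumPy.

```python
def select_extremes(pnodes, age, k=3):
    """Indices of the k patients with the fewest and the k with the most positive nodes.

    Sorting by the tuple (pnodes, age) mirrors np.argsort(..., order=["pnodes", "age"]).
    """
    order = sorted(range(len(pnodes)), key=lambda i: (pnodes[i], age[i]))
    return order[:k] + order[-k:]


# made-up values for eight patients, only to illustrate the ordering
pnodes = [1, 10, 1, 3, 24, 2, 1, 15]
age = [60, 45, 38, 52, 49, 70, 41, 66]
print(select_extremes(pnodes, age))  # [2, 6, 0, 1, 7, 4]
```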
The predicted risk scores indicate that the risk for the last three patients is quite a bit higher than that of the first three patients.
pd.Series(rsf.predict(X_test_sel))
0     91.477609
1    102.897552
2     75.883786
3    170.502092
4    171.210066
5    148.691835
dtype: float64
We can gain more detailed insight by considering the predicted survival function. It shows that the biggest difference occurs roughly within the first 750 days.
import matplotlib.pyplot as plt

# one row of survival probabilities per patient, evaluated at rsf.event_times_
surv = rsf.predict_survival_function(X_test_sel)
for i, s in enumerate(surv):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Survival probability")
plt.xlabel("Time in days")
plt.grid(True)
plt.legend()
Alternatively, we can also plot the predicted cumulative hazard function.
chf = rsf.predict_cumulative_hazard_function(X_test_sel)
for i, s in enumerate(chf):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Cumulative hazard")
plt.xlabel("Time in days")
plt.grid(True)
plt.legend()
Permutation-based Feature Importance
The implementation is based on scikit-learn’s Random Forest implementation and inherits many features, such as building trees in parallel. What’s currently missing is support for feature importances via the feature_importances_ attribute. This is due to the way scikit-learn’s implementation computes importances: it relies on a measure of impurity for each child node and defines importance as the amount of decrease in impurity due to a split. For regression, impurity would be measured by the variance, but for survival analysis there is no per-node impurity measure due to censoring. Instead, one could use the magnitude of the log-rank test statistic as an importance measure, but scikit-learn’s implementation doesn’t seem to allow this.
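For comparison, an impurity-based importance for a regression tree accumulates, for each split, the weighted decrease in variance and credits it to the feature used for the split. The sketch follows the weighting in scikit-learn's documentation of min_impurity_decrease. It is a simplified plain-Python illustration with hypothetical names. With censored survival times, there is no variance-like per-node quantity to plug in.

```python
def variance(values):
    """Mean squared deviation: the impurity of a regression tree node."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def weighted_impurity_decrease(parent, left, right, n_total):
    """N_t / N * (impurity - N_L / N_t * impurity_L - N_R / N_t * impurity_R)."""
    n_t = len(parent)
    return (n_t / n_total) * (
        variance(parent)
        - len(left) / n_t * variance(left)
        - len(right) / n_t * variance(right)
    )


# a root split that separates short from long outcomes perfectly
y = [1.0, 1.0, 5.0, 5.0]
print(weighted_impurity_decrease(y, y[:2], y[2:], n_total=len(y)))  # 4.0
```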
Fortunately, this is not a big concern, as scikit-learn’s definition of feature importance is non-standard and differs from what Leo Breiman proposed in the original Random Forest paper. Instead, we can estimate feature importance by permutation, which is generally preferred over scikit-learn’s definition. This is implemented in the ELI5 library, which is fully compatible with scikit-survival.
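import eli5
from eli5.sklearn import PermutationImportance

# shuffle each feature 15 times and record the drop in the estimator's score
# (the concordance index) on the test data
perm = PermutationImportance(rsf, n_iter=15, random_state=random_state)
perm.fit(X_test, y_test)
eli5.show_weights(perm, feature_names=feature_names)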
| Weight | Feature |
|---|---|
| 0.0676 ± 0.0229 | pnodes |
| 0.0206 ± 0.0139 | age |
| 0.0177 ± 0.0468 | progrec |
| 0.0086 ± 0.0098 | horTh |
| 0.0032 ± 0.0198 | tsize |
| 0.0032 ± 0.0060 | tgrade |
| -0.0007 ± 0.0018 | menostat |
| -0.0063 ± 0.0207 | estrec |
The result shows that the number of positive lymph nodes (pnodes) is by far the most important feature. If its relationship to survival time is removed (by random shuffling), the concordance index on the test data drops on average by 0.0676 points. Again, this agrees with the results from the original Random Survival Forests paper.
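The permutation procedure can be sketched in a few lines of plain Python. First compute the score on the unmodified data. Then repeatedly shuffle one column, re-score, and record the drop.

For a survival model, score would be the concordance index returned by the estimator's score method. The function name, the toy scorer, and the use of the sample standard deviation for the spread are illustrative assumptions; this is not ELI5's implementation.

```python
import random
import statistics


def permutation_importance(score, X, y, n_iter=15, seed=0):
    """Mean and standard deviation of the score drop when each column is shuffled.

    X is a list of rows (lists); score(X, y) returns a number where higher is better.
    n_iter must be at least 2 so that a standard deviation can be computed.
    """
    rng = random.Random(seed)
    baseline = score(X, y)
    importances = []
    for col in range(len(X[0])):
        drops = []
        for _ in range(n_iter):
            shuffled = [row[col] for row in X]
            rng.shuffle(shuffled)  # break the link between this feature and y
            X_perm = [row[:col] + [v] + row[col + 1:] for row, v in zip(X, shuffled)]
            drops.append(baseline - score(X_perm, y))
        importances.append((statistics.mean(drops), statistics.stdev(drops)))
    return importances


def toy_score(X, y):
    """Toy scorer: accuracy of predicting y from column 0; column 1 is pure noise."""
    return sum(row[0] == target for row, target in zip(X, y)) / len(y)


X = [[i, i % 3] for i in range(20)]
y = list(range(20))
for name, (mean, std) in zip(["signal", "noise"], permutation_importance(toy_score, X, y)):
    print(f"{name}: {mean:.3f} ± {std:.3f}")
```

A feature the model ignores gets an importance of exactly zero. Small negative values, as for menostat and estrec above, mean that shuffling happened to improve the score slightly by chance.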