scikit-survival 0.16 released

I am proud to announce the release of version 0.16.0 of scikit-survival. The biggest improvement in this release is that you can now change the evaluation metric used by an estimator’s score method (by default, Harrell’s concordance index). This is particularly useful for hyper-parameter optimization with scikit-learn’s GridSearchCV. You can now use as_concordance_index_ipcw_scorer, as_cumulative_dynamic_auc_scorer, or as_integrated_brier_score_scorer to adjust the score method to your needs. The example below illustrates how to use these in practice.

For a full list of changes in scikit-survival 0.16.0, please see the release notes.


Pre-built conda packages are available for Linux, macOS, and Windows via

 conda install -c sebp scikit-survival

Alternatively, scikit-survival can be installed from source following these instructions.

Hyper-Parameter Optimization with Alternative Metrics

The code below is also available as a notebook that can be executed directly.

In this example, we are going to use the German Breast Cancer Study Group 2 dataset. We want to fit a Random Survival Forest and optimize its max_depth hyper-parameter using scikit-learn’s GridSearchCV.

Let’s begin by loading the data.

import numpy as np
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import encode_categorical

gbsg_X, gbsg_y = load_gbsg2()
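gbsg_X = encode_categorical(gbsg_X)  # one-hot encode the categorical features

# time points at which the metrics are evaluated: a grid with unit spacing (days)
# between the 10th and 90th percentile of the observed times
lower, upper = np.percentile(gbsg_y["time"], [10, 90])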
gbsg_times = np.arange(lower, upper + 1)
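
To make the choice of evaluation time points concrete, here is a small sketch on hypothetical toy follow-up times (not the GBSG2 data). The grid starts at the 10th percentile and has unit spacing, presumably so the metrics are not evaluated in the tails of the follow-up period, where few subjects remain.

```python
import numpy as np

# Hypothetical toy follow-up times in days (the example above uses gbsg_y["time"]).
toy_times = np.array([100.0, 200.0, 300.0, 400.0, 500.0, 600.0,
                      700.0, 800.0, 900.0, 1000.0, 1100.0])

# Keep only the central part of the follow-up period.
lower, upper = np.percentile(toy_times, [10, 90])

# Grid with unit spacing starting at lower; with integer bounds, upper is included.
eval_times = np.arange(lower, upper + 1)
```

Here the grid runs from 200 to 1000 in steps of 1, i.e., 801 time points.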

Next, we create an instance of Random Survival Forest.

from sksurv.ensemble import RandomSurvivalForest

rsf_gbsg = RandomSurvivalForest(random_state=1)

We evaluate the performance of each hyper-parameter configuration by 3-fold cross-validation, shuffling the data before splitting it.

from sklearn.model_selection import KFold

cv = KFold(n_splits=3, shuffle=True, random_state=1)
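
For intuition, the sketch below mimics what shuffled 3-fold cross-validation does: the sample indices are permuted once and split into three nearly equal folds, and each fold is used once as the test set. The helper kfold_indices is hypothetical and uses NumPy’s random generator, so its folds will not match those produced by KFold(random_state=1).

```python
import numpy as np

def kfold_indices(n_samples, n_splits=3, seed=1):
    """Hypothetical helper, not scikit-learn's KFold: shuffle the sample
    indices once, split them into n_splits nearly equal folds, and use
    each fold once as the test set while the others form the training set."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    folds = np.array_split(indices, n_splits)
    for k in range(n_splits):
        test = folds[k]
        train = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train, test

# Ten samples are split into folds of size 4, 3 and 3.
splits = list(kfold_indices(10))
```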

Next, we define the set of hyper-parameters to evaluate. Here, we search for the best value of max_depth from 1 to 9 (the upper bound 10 is excluded). Note that we have to prefix max_depth with estimator__, because we are going to wrap the actual RandomSurvivalForest instance with one of the as_*_scorer functions mentioned above.

cv_param_grid = {
    "estimator__max_depth": np.arange(1, 10, dtype=int),
}

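The estimator__ prefix follows scikit-learn’s convention for nested parameters (<component>__<parameter>). The toy classes below are hypothetical stand-ins, not the actual scikit-survival wrappers; they only illustrate, in simplified form without scikit-learn’s validation or deeper nesting, how a wrapper can route estimator__max_depth to the estimator it wraps.

```python
class ToyForest:
    """Hypothetical stand-in for RandomSurvivalForest; only stores max_depth."""

    def __init__(self, max_depth=None):
        self.max_depth = max_depth


class ToyWrapper:
    """Hypothetical wrapper that exposes the wrapped estimator's parameters
    under an ``estimator__`` prefix, following scikit-learn's naming convention."""

    def __init__(self, estimator):
        self.estimator = estimator

    def set_params(self, **params):
        for key, value in params.items():
            prefix, sep, name = key.partition("__")
            if sep and prefix == "estimator":
                # "estimator__max_depth" -> set max_depth on the wrapped estimator
                setattr(self.estimator, name, value)
            else:
                # parameters without the prefix belong to the wrapper itself
                setattr(self, key, value)
        return self


# Conceptually, this is how each candidate in cv_param_grid is applied
# (the real GridSearchCV also clones the estimator for every candidate).
wrapped = ToyWrapper(ToyForest())
for depth in range(1, 10):
    wrapped.set_params(estimator__max_depth=depth)
```
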
Now, we can put all the pieces together and start searching for the best hyper-parameters that maximize concordance_index_ipcw.

from sklearn.model_selection import GridSearchCV
from sksurv.metrics import as_concordance_index_ipcw_scorer

gcv_cindex = GridSearchCV(
    as_concordance_index_ipcw_scorer(rsf_gbsg, tau=gbsg_times[-1]),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)
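
As a rough illustration of what a concordance index measures, the sketch below computes the unweighted (Harrell’s) concordance index in plain Python. This is a swapped-in, simpler variant: concordance_index_ipcw additionally weights comparable pairs by the inverse probability of censoring and only considers times up to tau, which is why tau=gbsg_times[-1] is passed above. The function name and toy data are hypothetical.

```python
def harrell_c_index(event, time, risk):
    """Unweighted (Harrell's) concordance index, a simplified stand-in for
    concordance_index_ipcw (no censoring weights, no truncation at tau).

    A pair (i, j) is comparable if subject i had an event and time[i] < time[j];
    it is concordant if subject i, who failed first, has the higher risk score.
    Tied risk scores count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored subject cannot be the one who failed first
        for j in range(len(time)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical toy data: higher risk scores should go with earlier events.
event = [True, True, False, True]
time = [2.0, 4.0, 5.0, 6.0]
risk = [0.9, 0.2, 0.1, 0.3]
cindex = harrell_c_index(event, time, risk)  # 4 of 5 comparable pairs are concordant
```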

The same process applies when optimizing hyper-parameters to maximize cumulative_dynamic_auc.

from sksurv.metrics import as_cumulative_dynamic_auc_scorer

gcv_iauc = GridSearchCV(
    as_cumulative_dynamic_auc_scorer(rsf_gbsg, times=gbsg_times),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)
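
Similarly, here is a sketch of the time-dependent AUC at a single time t: cases are subjects who experienced an event by t, controls are subjects still event-free after t, and the AUC is the fraction of case-control pairs in which the case receives the higher risk score. This version is unweighted and simply averages over time points; cumulative_dynamic_auc uses inverse-probability-of-censoring weights, so treat this only as an illustration of the definition. All names and data are hypothetical.

```python
def dynamic_auc(event, time, risk, t):
    """Unweighted cumulative/dynamic AUC at time t (no censoring weights).

    Cases: subjects with an event at or before t.
    Controls: subjects whose observed time is after t.
    """
    cases = [i for i in range(len(time)) if event[i] and time[i] <= t]
    controls = [j for j in range(len(time)) if time[j] > t]
    if not cases or not controls:
        raise ValueError("need at least one case and one control at time t")
    correct = 0.0
    for i in cases:
        for j in controls:
            if risk[i] > risk[j]:
                correct += 1.0  # case ranked above control
            elif risk[i] == risk[j]:
                correct += 0.5  # ties count as half
    return correct / (len(cases) * len(controls))


event = [True, True, False, True]
time = [2.0, 4.0, 5.0, 6.0]
risk = [0.9, 0.2, 0.1, 0.3]
aucs = [dynamic_auc(event, time, risk, t) for t in (3.0, 4.5)]
mean_auc = sum(aucs) / len(aucs)  # plain average over time points, a simplification
```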

While as_concordance_index_ipcw_scorer and as_cumulative_dynamic_auc_scorer can be used with any estimator, as_integrated_brier_score_scorer is only available for estimators that provide the predict_survival_function method, which includes RandomSurvivalForest. In this case, the hyper-parameters that maximize the negative integrated time-dependent Brier score are selected: GridSearchCV always maximizes the score, and a lower Brier score indicates better performance.

from sksurv.metrics import as_integrated_brier_score_scorer

gcv_ibs = GridSearchCV(
    as_integrated_brier_score_scorer(rsf_gbsg, times=gbsg_times),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)
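
The following sketch computes the integrated Brier score under the simplifying assumption that there is no censoring, so all inverse-probability-of-censoring weights equal 1. The Brier score at t is the mean squared difference between the observed status 1{T_i > t} and the predicted survival probability. It is integrated over the time points with the trapezoidal rule and divided by the length of the interval. Negating it lets a maximizing search pick the model with the lowest IBS. All names and data are hypothetical.

```python
def brier_score_uncensored(time, surv_prob, t):
    """Brier score at time t ASSUMING NO CENSORING (all IPCW weights equal 1).

    surv_prob[i] is the predicted probability that subject i survives beyond t.
    """
    n = len(time)
    return sum(((1.0 if time[i] > t else 0.0) - surv_prob[i]) ** 2 for i in range(n)) / n


def integrated_brier_score_uncensored(time, surv_fns, times):
    """Trapezoidal integral of the Brier score over `times`, divided by the
    interval length; surv_fns[i](t) is subject i's predicted survival at t."""
    scores = [brier_score_uncensored(time, [f(t) for f in surv_fns], t) for t in times]
    area = sum(
        (times[k + 1] - times[k]) * (scores[k] + scores[k + 1]) / 2
        for k in range(len(times) - 1)
    )
    return area / (times[-1] - times[0])


# Hypothetical toy data: two subjects, both with an observed event.
time = [2.0, 4.0]
times = [1.0, 2.0, 3.0, 4.0, 5.0]
candidates = {
    # predicts survival exactly: S_i(t) = 1 while t < T_i, else 0
    "perfect": [lambda t, T=T: 1.0 if t < T else 0.0 for T in time],
    # uninformative: always predicts a 50% chance of survival
    "constant": [lambda t: 0.5 for _ in time],
}
# Maximizing -IBS is the same as minimizing IBS.
neg_ibs = {
    name: -integrated_brier_score_uncensored(time, fns, times)
    for name, fns in candidates.items()
}
best = max(neg_ibs, key=neg_ibs.get)
```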

Finally, we can visualize the results of the grid search and compare the best performing hyper-parameter configurations (marked with a red dot).

import matplotlib.pyplot as plt

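def plot_grid_search_results(gcv, ax, name):
    ax.errorbar(  # mean cross-validated score +/- one standard deviation per max_depth
        x=np.asarray(gcv.cv_results_["param_estimator__max_depth"], dtype=int),
        y=gcv.cv_results_["mean_test_score"],
        yerr=gcv.cv_results_["std_test_score"],
    )
    # mark the best configuration with a red dot
    ax.plot(gcv.best_params_["estimator__max_depth"], gcv.best_score_, "ro")
    ax.set_ylabel(name)
    ax.yaxis.grid(True)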

_, axs = plt.subplots(3, 1, figsize=(6, 6), sharex=True)

plot_grid_search_results(gcv_cindex, axs[0], "c-index")
plot_grid_search_results(gcv_iauc, axs[1], "iAUC")
plot_grid_search_results(gcv_ibs, axs[2], "$-$IBS")

Results of hyper-parameter optimization.

When optimizing for the concordance index, a high maximum depth works best, whereas the iAUC and the integrated Brier score are best with a maximum depth of 5 and 6, respectively.

Sebastian Pölsterl
AI Researcher

My research interests include machine learning for time-to-event analysis, causal inference and biomedical applications.