scikit-survival 0.16 released
I am proud to announce the release of version 0.16.0 of scikit-survival.
The biggest improvement in this release is that you can now change the evaluation metric that is used in estimators’ score method. This is particularly useful for hyper-parameter optimization using scikit-learn’s GridSearchCV. You can now use as_concordance_index_ipcw_scorer, as_cumulative_dynamic_auc_scorer, or as_integrated_brier_score_scorer to adjust the score method to your needs. The example below illustrates how to use these in practice.
For a full list of changes in scikit-survival 0.16.0, please see the release notes.
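Conceptually, each of these scorer classes wraps an estimator and changes only what its score method returns. The following is a minimal, hypothetical sketch of that idea in plain Python. MeanPredictor, MetricOverride, and neg_mean_absolute_error are illustrative names, not part of scikit-survival. The real wrappers also implement scikit-learn's estimator interface and compute survival metrics rather than the toy regression metric used here.

```python
class MeanPredictor:
    """Toy regressor that always predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]

    def score(self, X, y):
        # Built-in metric: negative mean squared error (greater is better)
        preds = self.predict(X)
        return -sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)


class MetricOverride:
    """Hypothetical wrapper: delegates fit/predict, replaces the metric used by score()."""

    def __init__(self, estimator, metric):
        self.estimator = estimator
        self.metric = metric  # callable(y_true, y_pred) -> float, greater is better

    def fit(self, X, y):
        self.estimator.fit(X, y)
        return self

    def predict(self, X):
        return self.estimator.predict(X)

    def score(self, X, y):
        # Same predictions as the wrapped estimator, different metric
        return self.metric(y, self.predict(X))


def neg_mean_absolute_error(y_true, y_pred):
    return -sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


X = [[0], [1], [2], [3]]
y = [1.0, 2.0, 3.0, 6.0]
plain = MeanPredictor().fit(X, y)
wrapped = MetricOverride(MeanPredictor(), neg_mean_absolute_error).fit(X, y)
print(plain.score(X, y))    # -3.5 (negative MSE)
print(wrapped.score(X, y))  # -1.5 (negative MAE)
```

When GridSearchCV is not given an explicit scoring argument, it calls the estimator's own score method on each held-out fold. Swapping the metric inside score therefore changes which hyper-parameters win without touching the estimator itself.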
Installation
Pre-built conda packages are available for Linux, macOS, and Windows via
conda install -c sebp scikit-survival
Alternatively, scikit-survival can be installed from source following these instructions.
Hyper-Parameter Optimization with Alternative Metrics
The code below is also available as a notebook.
In this example, we are going to use the German Breast Cancer Study Group 2 dataset. We want to fit a Random Survival Forest and optimize its max_depth hyper-parameter using scikit-learn’s GridSearchCV.
Let’s begin by loading the data.
import numpy as np
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import encode_categorical
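gbsg_X, gbsg_y = load_gbsg2()
# One-hot encode the categorical features
gbsg_X = encode_categorical(gbsg_X)
# Evaluation time points in steps of 1, between the 10th and 90th percentile of observed times
lower, upper = np.percentile(gbsg_y["time"], [10, 90])
gbsg_times = np.arange(lower, upper + 1)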
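The grid gbsg_times covers only the central part of the follow-up distribution. One likely reason for this, which is an assumption rather than something stated here, is that time-dependent metrics become unstable where few subjects remain at risk, and that they expect time points within the observed follow-up range. The sketch below repeats the same construction on synthetic exponential follow-up times (hypothetical data, not GBSG2) to show what such a grid looks like.

```python
import numpy as np

# Synthetic follow-up times in days (hypothetical, not the GBSG2 data)
rng = np.random.default_rng(0)
time = rng.exponential(scale=1000.0, size=200)

# Same construction as in the example: 10th and 90th percentile, step of 1
lower, upper = np.percentile(time, [10, 90])
times = np.arange(lower, upper + 1)

# Share of subjects whose follow-up ends before the first evaluation time
frac_before = np.mean(time < times[0])
print(f"{len(times)} time points from {times[0]:.1f} to {times[-1]:.1f}")
print(f"fraction of follow-up times below the first point: {frac_before:.2f}")
```

Because np.arange stops strictly before upper + 1, the last grid point, which the example uses as tau, may lie slightly above the 90th percentile, by less than one time unit.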
Next, we create an instance of Random Survival Forest.
from sksurv.ensemble import RandomSurvivalForest
rsf_gbsg = RandomSurvivalForest(random_state=1)
We define that we want to evaluate the performance of each hyper-parameter configuration by 3-fold cross-validation.
from sklearn.model_selection import KFold
cv = KFold(n_splits=3, shuffle=True, random_state=1)
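To make the 3-fold scheme concrete, here is a simplified sketch of shuffled k-fold splitting with NumPy. kfold_indices is a hypothetical helper, not scikit-learn's KFold. Its shuffling differs, so the indices will not match, but the fold sizes follow the same rule. Each sample lands in a test fold exactly once. For every candidate, GridSearchCV reports the mean and standard deviation of the three test-fold scores as mean_test_score and std_test_score.

```python
import numpy as np


def kfold_indices(n_samples, n_splits=3, seed=1):
    """Simplified shuffled k-fold split (a sketch, not scikit-learn's code)."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    # Fold sizes differ by at most one sample; larger folds come first
    folds = np.array_split(indices, n_splits)
    for k in range(n_splits):
        test_idx = folds[k]
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train_idx, test_idx


for train_idx, test_idx in kfold_indices(10):
    print(len(train_idx), len(test_idx))  # prints 6 4, then 7 3 twice
```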
Next, we define the set of hyper-parameters to evaluate. Here, we search for the best value for max_depth between 1 and 10 (exclusive). Note that we have to prefix max_depth with estimator__, because we are going to wrap the actual RandomSurvivalForest instance with one of the scorer classes listed above.
cv_param_grid = {
"estimator__max_depth": np.arange(1, 10, dtype=int),
}
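The estimator__ prefix follows scikit-learn's convention for nested parameters. GridSearchCV calls set_params on the wrapper, which splits a name like estimator__max_depth at the first double underscore and forwards max_depth to the nested object. Below is a simplified sketch of that routing. Forest, Wrapper, and split_param are hypothetical stand-ins. scikit-learn's own BaseEstimator.set_params additionally validates parameter names and supports deeper nesting.

```python
def split_param(name):
    """Split 'component__param' at the first double underscore."""
    component, sep, param = name.partition("__")
    if not sep:
        return None, name  # a parameter of the outer object itself
    return component, param


class Forest:
    """Hypothetical stand-in for RandomSurvivalForest."""

    def __init__(self, max_depth=None):
        self.max_depth = max_depth


class Wrapper:
    """Hypothetical stand-in for a scorer wrapper around an estimator."""

    def __init__(self, estimator, tau=None):
        self.estimator = estimator
        self.tau = tau

    def set_params(self, **params):
        for name, value in params.items():
            component, param = split_param(name)
            if component is None:
                setattr(self, param, value)
            else:
                # e.g. "estimator__max_depth" -> self.estimator.max_depth
                setattr(getattr(self, component), param, value)
        return self


wrapped = Wrapper(Forest())
wrapped.set_params(estimator__max_depth=3, tau=1000)
print(wrapped.estimator.max_depth, wrapped.tau)  # 3 1000
```

Without the prefix, max_depth would be looked up on the wrapper itself, which has no such parameter.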
Now, we can put all the pieces together and start searching for the best hyper-parameters that maximize concordance_index_ipcw.
from sklearn.model_selection import GridSearchCV
from sksurv.metrics import as_concordance_index_ipcw_scorer
gcv_cindex = GridSearchCV(
as_concordance_index_ipcw_scorer(rsf_gbsg, tau=gbsg_times[-1]),
param_grid=cv_param_grid,
cv=cv,
).fit(gbsg_X, gbsg_y)
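To give some intuition for what concordance_index_ipcw measures, the sketch below computes a plain, unweighted (Harrell-style) concordance index with optional truncation at tau. This is a simplification. concordance_index_ipcw implements Uno's estimator, which also weights comparable pairs with inverse probability of censoring weights estimated from the training data. The function name and the toy data are hypothetical.

```python
import numpy as np


def concordance_index_sketch(event, time, risk, tau=None):
    """Unweighted C-index: share of comparable pairs ranked correctly by risk."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        # Only subjects with an observed event (before tau, if given) anchor a pair
        if not event[i] or (tau is not None and time[i] >= tau):
            continue
        later = time > time[i]  # subjects known to survive longer than subject i
        comparable += int(later.sum())
        # Concordant if subject i got the higher risk; ties in risk count as 0.5
        concordant += np.sum(risk[i] > risk[later]) + 0.5 * np.sum(risk[i] == risk[later])
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


event = [1, 0, 1, 1]
time = [2.0, 3.0, 4.0, 6.0]
risk = [0.6, 0.5, 0.9, 0.1]  # higher risk = predicted to experience the event sooner
print(concordance_index_sketch(event, time, risk))         # 0.75
print(concordance_index_sketch(event, time, risk, tau=3))  # about 0.667
```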
The same process applies when optimizing hyper-parameters to maximize cumulative_dynamic_auc.
from sksurv.metrics import as_cumulative_dynamic_auc_scorer
gcv_iauc = GridSearchCV(
as_cumulative_dynamic_auc_scorer(rsf_gbsg, times=gbsg_times),
param_grid=cv_param_grid,
cv=cv,
).fit(gbsg_X, gbsg_y)
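Similarly, the sketch below computes an unweighted cumulative/dynamic AUC at a single time point t. Cases are subjects who experienced an event by t, and controls are subjects still event-free after t. The AUC is the fraction of case/control pairs in which the case received the higher risk score. cumulative_dynamic_auc also weights cases by inverse probability of censoring weights and evaluates every point in times. The scorer then reduces these values to a single summary number, labeled iAUC in the plots at the end. The helper name and data are hypothetical.

```python
import numpy as np


def cumulative_dynamic_auc_at(event, time, risk, t):
    """Unweighted cumulative/dynamic AUC at time t (a simplified sketch)."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    cases = event & (time <= t)  # event at or before t
    controls = time > t          # still event-free after t
    # Subjects censored before t are neither cases nor controls
    if not cases.any() or not controls.any():
        raise ValueError("need at least one case and one control at time t")
    # Compare every case with every control via broadcasting
    diff = risk[cases][:, None] - risk[controls][None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


event = [1, 0, 1, 1]
time = [2.0, 3.0, 4.0, 6.0]
risk = [0.6, 0.5, 0.9, 0.1]
print(cumulative_dynamic_auc_at(event, time, risk, t=3.0))  # 0.5
print(cumulative_dynamic_auc_at(event, time, risk, t=4.5))  # 1.0
```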
While as_concordance_index_ipcw_scorer and as_cumulative_dynamic_auc_scorer can be used with any estimator, as_integrated_brier_score_scorer is only available for estimators that provide the predict_survival_function method, which includes RandomSurvivalForest. With as_integrated_brier_score_scorer, the hyper-parameters that maximize the negative integrated time-dependent Brier score are selected, because a lower Brier score indicates better performance.
from sksurv.metrics import as_integrated_brier_score_scorer
gcv_ibs = GridSearchCV(
as_integrated_brier_score_scorer(rsf_gbsg, times=gbsg_times),
param_grid=cv_param_grid,
cv=cv,
).fit(gbsg_X, gbsg_y)
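The integrated Brier score is negated because GridSearchCV treats higher scores as better. The sketch below averages a Brier-score curve over the evaluation window using the trapezoidal rule, then negates the result and picks the best candidate. The two curves are made-up numbers, not GBSG2 results. integrate_brier_curve is a hypothetical, simplified stand-in for sksurv.metrics.integrated_brier_score, which also has to estimate each time-dependent Brier score from censored data first.

```python
import numpy as np


def integrate_brier_curve(times, brier_scores):
    """Average a Brier-score curve over [times[0], times[-1]] (trapezoidal rule)."""
    times = np.asarray(times, dtype=float)
    brier_scores = np.asarray(brier_scores, dtype=float)
    area = np.sum(np.diff(times) * (brier_scores[1:] + brier_scores[:-1]) / 2.0)
    return area / (times[-1] - times[0])


# Hypothetical Brier-score curves for two max_depth settings (made-up numbers)
times = np.array([100.0, 200.0, 300.0, 400.0])
curves = {
    3: np.array([0.10, 0.15, 0.20, 0.22]),
    6: np.array([0.08, 0.14, 0.18, 0.21]),
}

# Higher is better for GridSearchCV, so the loss is negated before comparing
scores = {depth: -integrate_brier_curve(times, bs) for depth, bs in curves.items()}
best_depth = max(scores, key=scores.get)
print(best_depth)  # 6: lowest integrated Brier score, i.e. highest negated score
```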
Finally, we can visualize the results of the grid search and compare the best performing hyper-parameter configurations (marked with a red dot).
import matplotlib.pyplot as plt
def plot_grid_search_results(gcv, ax, name):
    # Mean and standard deviation of the cross-validated score for each max_depth
    ax.errorbar(
        x=gcv.cv_results_["param_estimator__max_depth"].filled(),
        y=gcv.cv_results_["mean_test_score"],
        yerr=gcv.cv_results_["std_test_score"],
    )
    # Mark the best configuration with a red dot
    ax.plot(
        gcv.best_params_["estimator__max_depth"],
        gcv.best_score_,
        "ro",
    )
    ax.set_ylabel(name)
    ax.yaxis.grid(True)
_, axs = plt.subplots(3, 1, figsize=(6, 6), sharex=True)
axs[-1].set_xlabel("max_depth")
plot_grid_search_results(gcv_cindex, axs[0], "c-index")
plot_grid_search_results(gcv_iauc, axs[1], "iAUC")
plot_grid_search_results(gcv_ibs, axs[2], "$-$IBS")
When optimizing for the concordance index, a high maximum depth works best, whereas the iAUC and the integrated Brier score are best at a maximum depth of 5 and 6, respectively.