scikit-survival 0.21.0 released
Today marks the release of scikit-survival 0.21.0. This release brings several exciting new features and significant performance improvements:
- Pointwise confidence intervals for the Kaplan-Meier estimator.
- Early stopping in GradientBoostingSurvivalAnalysis.
- Improved performance of fitting SurvivalTree and RandomSurvivalForest.
- Reduced memory footprint of concordance_index_censored.
Pointwise Confidence Intervals for the Kaplan-Meier Estimator
kaplan_meier_estimator() can now estimate pointwise confidence intervals by specifying the conf_type parameter.
import matplotlib.pyplot as plt

from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator

_, y = load_veterans_lung_cancer()

time, survival_prob, conf_int = kaplan_meier_estimator(
    y["Status"], y["Survival_in_days"], conf_type="log-log"
)

plt.step(time, survival_prob, where="post")
plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
Kaplan-Meier curve with pointwise confidence intervals.
Early Stopping in GradientBoostingSurvivalAnalysis
Early stopping enables us to determine when the model is sufficiently complex.
This is usually done by continuously evaluating the model on held-out data.
For GradientBoostingSurvivalAnalysis, the easiest way to achieve this is by setting n_iter_no_change and, optionally, validation_fraction (defaults to 0.1).
from sksurv.datasets import load_whas500
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
X, y = load_whas500()
model = GradientBoostingSurvivalAnalysis(
    n_estimators=1000, max_depth=2, subsample=0.8, n_iter_no_change=10, random_state=0,
)
model.fit(X, y)
print(model.n_estimators_)
In this example, model.n_estimators_ indicates that fitting stopped after 73 iterations, instead of the maximum of 1000 iterations.
Alternatively, one can provide a custom callback function to the fit method. If the callback returns True, training is stopped.
model = GradientBoostingSurvivalAnalysis(
    n_estimators=1000, max_depth=2, subsample=0.8, random_state=0,
)

def early_stopping_monitor(iteration, model, args):
    """Stop training if there was no improvement in the last 10 iterations."""
    start = max(0, iteration - 10)
    end = iteration + 1
    oob_improvement = model.oob_improvement_[start:end]
    return all(oob_improvement < 0)
model.fit(X, y, monitor=early_stopping_monitor)
print(model.n_estimators_)
In the example above, early stopping is determined by checking the last 10 entries of the oob_improvement_ attribute, which contains the improvement in loss on the out-of-bag samples relative to the previous iteration. This requires setting subsample to a value smaller than 1, here 0.8.
Using this approach, training stopped after 114 iterations.
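To get a feel for when the monitor triggers, one could plot the out-of-bag improvements of the model fitted above. This is just an illustrative sketch, not part of the original example:

import matplotlib.pyplot as plt

# plot the improvement in out-of-bag loss for each boosting iteration;
# early stopping kicks in once a run of entries stays below zero
plt.plot(model.oob_improvement_)
plt.axhline(0.0, linestyle="--", color="gray")
plt.xlabel("iteration")
plt.ylabel("OOB improvement")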
Improved Performance of SurvivalTree and RandomSurvivalForest
Another exciting improvement in scikit-survival 0.21.0 stems from a re-write of the training routine of SurvivalTree, which results in roughly 3x faster training times.
Runtime comparison of fitting SurvivalTree.
The plot above compares the time required to fit a single SurvivalTree on data with 25 features and a varying number of samples. The performance difference becomes notable for data with 1000 samples and above.
Note that this improvement also speeds up fitting RandomSurvivalForest and ExtraSurvivalTrees.
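Here is a minimal sketch of how such a timing comparison could be reproduced. The synthetic data below (25 features, exponentially distributed times, roughly 70% observed events) is an assumption for illustration, not the benchmark setup behind the plot above:

import time

import numpy as np

from sksurv.tree import SurvivalTree
from sksurv.util import Surv

rng = np.random.default_rng(0)

for n_samples in (1000, 5000, 10000):
    # synthetic survival data: 25 features, exponential event times,
    # and roughly 70% of samples with an observed event (assumed setup)
    X = rng.standard_normal((n_samples, 25))
    event_time = rng.exponential(scale=10.0, size=n_samples)
    event = rng.random(n_samples) < 0.7
    y = Surv.from_arrays(event=event, time=event_time)

    start = time.perf_counter()
    SurvivalTree().fit(X, y)
    print(f"{n_samples} samples: {time.perf_counter() - start:.2f}s")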
Improved concordance index
Another performance improvement is due to Christine Poerschke who significantly reduced the memory footprint of concordance_index_censored(). With scikit-survival 0.21.0, memory usage scales linearly, instead of quadratically, in the number of samples, making performance evaluation on large datasets much more manageable.
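As a quick illustration of its usage, the sketch below computes the concordance index from a model's predicted risk scores, assuming the whas500 outcome fields fstat (event indicator) and lenfol (follow-up time); evaluating on the training data keeps the example short:

from sksurv.datasets import load_whas500
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

X, y = load_whas500()

model = GradientBoostingSurvivalAnalysis(n_estimators=100, random_state=0)
model.fit(X, y)

# concordance_index_censored compares the order of predicted risk scores
# against observed outcomes over all comparable pairs of samples
cindex, concordant, discordant, tied_risk, tied_time = concordance_index_censored(
    y["fstat"], y["lenfol"], model.predict(X)
)
print(round(cindex, 3))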
For a full list of changes in scikit-survival 0.21.0, please see the release notes.
Install
Pre-built packages are available for Linux, macOS (Intel), and Windows, either
via pip:
pip install scikit-survival
or via conda:
conda install -c sebp scikit-survival