Announcing scikit-survival – a Python library for survival analysis built on top of scikit-learn

I've meant to do this release for quite a while now, and last week I finally had some time to package everything and update the dependencies. scikit-survival contains the majority of the code I developed during my Ph.D.

About Survival Analysis

Survival analysis – also referred to as reliability analysis in engineering – refers to a type of problem in statistics where the objective is to establish a connection between a set of measurements (often called features or covariates) and the time to an event. The name survival analysis originates from clinical research: in many clinical studies, one is interested in predicting the time to death, i.e., survival. Broadly speaking, survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. Consider a clinical study that investigates coronary heart disease and has been carried out over a one-year period, as in the figure below.

Patient A was lost to follow-up after three months with no recorded cardiovascular event, patient B experienced an event four and a half months after enrollment, patient C experienced an event during the study period, patient D withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the exact time of a cardiovascular event could only be recorded for patients B and C; their records are uncensored. For the remaining patients, it is unknown whether they did or did not experience an event after they left the study. The only valid information available for patients A, D, and E is that they were event-free up to their last follow-up. Therefore, their records are censored.

Formally, each patient record consists of a set of covariates $x \in \mathbb{R}^d$, and the time $t > 0$ when an event occurred or the time $c > 0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\delta \in \{0, 1\}$ and the observable survival time $y > 0$. The observable time $y$ of a right-censored sample is defined as
\[ y = \min(t, c) =
\begin{cases}
t & \text{if } \delta = 1 , \\
c & \text{if } \delta = 0 .
\end{cases} \]
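As a concrete sketch of these definitions, the event indicator and observable time can be computed with NumPy (the times below are made up for illustration):

```python
import numpy as np

# hypothetical true event times t and censoring times c for five patients
t = np.array([2.0, 4.5, 8.0, 1.5, 9.0])
c = np.array([3.0, 6.0, 5.0, 2.0, 12.0])

delta = (t <= c).astype(int)   # event indicator: 1 = event observed, 0 = censored
y = np.minimum(t, c)           # observable survival time y = min(t, c)
```

For the third patient, censoring occurs before the event, so only $y = 5.0$ with $\delta = 0$ would appear in the data.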

What is scikit-survival?

Recently, many methods from machine learning have been adapted to this kind of problem – random forests, gradient boosting, and support vector machines – many of which are only available for R, but not Python. Some of the traditional models are part of lifelines or statsmodels, but none of those libraries plays nicely with scikit-learn, which is the quasi-standard machine learning framework for Python.

This is exactly where scikit-survival comes in. Models implemented in scikit-survival follow the scikit-learn interfaces. Thus, it is possible to use PCA from scikit-learn for dimensionality reduction and feed the low-dimensional representation to a survival model from scikit-survival, or cross-validate a survival model using the classes from scikit-learn. You can see an example of the latter in this notebook.

Download and Install

The source code is available at GitHub and can be installed via Anaconda (currently only for Linux) or pip.

conda install -c sebp scikit-survival

pip install scikit-survival

The API documentation is available here, and scikit-survival ships with a couple of sample datasets from the medical domain to get you started.


Hey, seems like a great package.
I have some issues regarding the "y", which should, as mentioned, be a structured array with the first column being the censoring indicator – but ndarrays, pandas DataFrames, and tuples don't work. I have seen in the example from the Veteran dataset that it seems to be more like an ndarray of tuples, but neither works.
So my question is: how do I construct y from a DataFrame or an ndarray?
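One way to build such a y from a DataFrame is with a plain NumPy structured array (the column names below are illustrative; the boolean event field has to come before the time field):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"status": [True, False, True],
                   "time": [10.0, 25.0, 7.0]})

# structured array: boolean event indicator first, float time second
y = np.empty(len(df), dtype=[("event", bool), ("time", float)])
y["event"] = df["status"].to_numpy()
y["time"] = df["time"].to_numpy()
```

Depending on the scikit-survival version, a helper such as `sksurv.util.Surv.from_dataframe` may also construct this array directly.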



Hey Sebastian,

Thanks for this library, looks really cool. You wrote that the lifelines library does not "play nice with scikit-learn." Can you elaborate on the problems that exist in lifelines/sklearn that motivate the use (or development) of scikit-survival?

Also, just a tip - you wrote in the comment above that "Unfortunately, it is not possible to use a pandas.DataFrame, because scikit-learn only works with numpy arrays." It's definitely possible to use scikit-learn with Pandas. My usual predictive workflow is to use a scikit-learn pipeline with custom transforms/steps, where each step accepts my data as a pandas dataframe, processes the data using pandas, and passes along a pandas dataframe to the next step. If you end the pipeline with a classifier/regressor that accepts pandas dataframes (like xgboost) then it all works very elegantly, including getting things like feature-importances reported by dataframe-column-name.
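The DataFrame-in/DataFrame-out workflow described in the comment above can be sketched with scikit-learn's `FunctionTransformer` (the function and column names here are made up for illustration):

```python
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

def add_ratio(df):
    # derive a new feature while keeping the data as a DataFrame
    df = df.copy()
    df["ratio"] = df["a"] / df["b"]
    return df

pipe = Pipeline([("features", FunctionTransformer(add_ratio))])
out = pipe.fit_transform(pd.DataFrame({"a": [1.0, 2.0],
                                       "b": [2.0, 4.0]}))
```

Each step receives and returns a DataFrame, so column names survive the whole pipeline.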

Thanks again for developing and publishing this!

Looks great but maddening that structured numpy arrays are used instead of pandas! Can this be fixed in a future release please?