Survival Analysis for Deep Learning Tutorial for TensorFlow 2

A while back, I posted the Survival Analysis for Deep Learning tutorial. That tutorial was written for TensorFlow 1 using the tf.estimator API. The changes between version 1 and the current TensorFlow 2 are significant enough that the code no longer runs on a recent TensorFlow version. I therefore created a new version of the tutorial that is compatible with TensorFlow 2. The text is essentially identical, but the training and evaluation procedure has changed.

The complete notebook is available on GitHub, or you can run it directly using Google Colaboratory.

Notes on porting to TensorFlow 2

A nice feature of TensorFlow 2 is that writing custom metrics (such as the concordance index) to TensorBoard no longer requires creating a Summary protocol buffer manually. Instead, it suffices to call tf.summary.scalar and pass it a name and a float. So instead of

import tensorflow as tf
from sksurv.metrics import concordance_index_censored
from tensorflow.core.framework import summary_pb2

c_index_metric = concordance_index_censored(…)[0]

# TensorFlow 1: wrap the value in a Summary protocol buffer by hand
writer = tf.summary.FileWriterCache.get(output_dir)
buf = summary_pb2.Summary(value=[summary_pb2.Summary.Value(
    tag="c-index", simple_value=c_index_metric)])
writer.add_summary(buf, global_step=global_step)

you can just do

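import tensorflow as tf
from sksurv.metrics import concordance_index_censored

# The file writer has to be made the default writer via as_default()
with tf.summary.create_file_writer(output_dir).as_default():
    c_index_metric = concordance_index_censored(…)[0]
    tf.summary.scalar("c-index", c_index_metric, step=step)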

Another feature that I liked is that you can now iterate over an instance of tf.data.Dataset and directly access the tensors and their values. This is much more convenient than first calling make_one_shot_iterator to obtain an iterator and then calling get_next() on it to get the actual tensors.
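To illustrate the direct iteration, here is a minimal sketch. The feature values, event indicators, and variable names are made up for this example and are not taken from the tutorial's data.

```python
import tensorflow as tf

# Hypothetical toy data: four samples with two features each,
# plus a binary event indicator per sample.
features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
events = tf.constant([1, 0, 1, 1])

dataset = tf.data.Dataset.from_tensor_slices((features, events)).batch(2)

# In TensorFlow 2 (eager execution), a Dataset is a plain Python iterable.
# Each element is a tuple of tensors whose values are available via .numpy().
batches = []
for x_batch, event_batch in dataset:
    batches.append((tuple(x_batch.shape), event_batch.numpy().tolist()))
```

Each loop iteration yields concrete tensors, so shapes and values can be inspected without creating a session or an explicit iterator.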

Unfortunately, I also encountered some negatives when moving to TensorFlow 2. First of all, there’s currently no officially supported way to produce a view of the executed graph that is identical to what you get with TensorFlow 1, unless you use the Keras training loop with the TensorBoard callback. There’s tf.summary.trace_export, which, as described in this guide, sounds like it would produce the graph. However, with this approach you can only view individual operations in TensorBoard; you can’t inspect the sizes of an operation’s input and output tensors. After searching for a while, I eventually found the answer in a Stack Overflow post, and, as it turns out, that is exactly what the TensorBoard callback is doing.

Another thing I found odd is that if you define your custom loss as a subclass of tf.keras.losses.Loss, it insists on exactly two inputs, y_true and y_pred. In the case of Cox’s proportional hazards loss, the true label comprises an event indicator and an indicator matrix specifying which pairs in a batch are comparable. Luckily, the contents of y_true don’t get checked, so you can just pass a list, but I would prefer to write something like

loss_fn(y_true_event=y_event, y_true_riskset=y_riskset, y_pred=pred_risk_score)

instead of

loss_fn(y_true=[y_event, y_riskset], y_pred=pred_risk_score)
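One way to get the keyword-style call is a thin wrapper around the Loss subclass. The following is a minimal sketch under stated assumptions, not the tutorial's implementation: the names make_riskset, CoxPHLoss, and cox_ph_loss are hypothetical, the loss ignores tied event times, and the toy inputs are made up.

```python
import math
import tensorflow as tf


def make_riskset(time):
    """Boolean matrix: entry (i, j) is True if sample j is still at risk at time[i]."""
    time = tf.convert_to_tensor(time)
    return tf.expand_dims(time, 0) >= tf.expand_dims(time, 1)


class CoxPHLoss(tf.keras.losses.Loss):
    """Sketch of the negative partial log-likelihood of Cox's model (no tie handling)."""

    def call(self, y_true, y_pred):
        # y_true is a list [event, riskset]; Keras does not inspect its contents.
        event, riskset = y_true
        y_pred = tf.convert_to_tensor(y_pred)        # shape (n, 1)
        event = tf.cast(event, y_pred.dtype)         # shape (n, 1)
        riskset = tf.cast(riskset, tf.bool)          # shape (n, n)
        pred_t = tf.transpose(y_pred)                # shape (1, n)
        # Exclude samples outside the risk set from the log-sum-exp.
        neg_inf = tf.constant(float("-inf"), y_pred.dtype)
        masked = tf.where(riskset, pred_t, neg_inf)
        log_risk = tf.reduce_logsumexp(masked, axis=1, keepdims=True)
        # Only samples with an observed event contribute to the loss.
        return event * (log_risk - y_pred)


def cox_ph_loss(y_true_event, y_true_riskset, y_pred):
    """Hypothetical keyword-style wrapper that packs the two labels into y_true."""
    return CoxPHLoss()(y_true=[y_true_event, y_true_riskset], y_pred=y_pred)


# Toy batch: three samples, all with an observed event, identical risk scores.
time = tf.constant([1.0, 2.0, 3.0])
y_event = tf.constant([[1.0], [1.0], [1.0]])
y_riskset = make_riskset(time)
pred_risk_score = tf.zeros((3, 1))

loss_value = cox_ph_loss(
    y_true_event=y_event, y_true_riskset=y_riskset, y_pred=pred_risk_score)
```

With identical risk scores, each event contributes the logarithm of its risk-set size, so the mean over this batch is (log 3 + log 2 + log 1) / 3.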

Finally, although eager execution is now enabled by default, the code runs significantly faster in graph mode, which you enable by annotating your model’s call method with @tf.function. I guess you are only supposed to use eager execution for debugging purposes.
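As a minimal sketch of the difference, the same Python function can be run eagerly or compiled into a graph with tf.function. The function predict_risk and its weights are made up for illustration; for a Keras model the decorator would go on the call method instead.

```python
import tensorflow as tf

# Hypothetical linear risk score with fixed weights.
weights = tf.Variable([[1.0], [2.0], [3.0]])


def predict_risk(x):
    # Runs operation by operation when called directly (eager execution).
    return tf.matmul(x, weights)


# tf.function traces the Python function once per input signature and
# afterwards executes the resulting graph.
predict_risk_graph = tf.function(predict_risk)

x = tf.constant([[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]])
eager_out = predict_risk(x).numpy().tolist()
graph_out = predict_risk_graph(x).numpy().tolist()
```

Both variants return the same values; only the execution mode differs, and any speed-up depends on the model and hardware.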

Sebastian Pölsterl
AI Researcher

My research interests include machine learning for time-to-event analysis, causal inference and biomedical applications.
