Survival Analysis for Deep Learning Tutorial for TensorFlow 2
A while back, I posted the Survival Analysis for Deep Learning tutorial. That tutorial was written for TensorFlow 1 using the tf.estimator API. The changes between version 1 and the current TensorFlow 2 are quite significant, which is why the code does not run with a recent TensorFlow version. Therefore, I created a new version of the tutorial that is compatible with TensorFlow 2. The text is basically identical, but the training and evaluation procedures changed.
The complete notebook is available on GitHub, or you can run it directly using Google Colaboratory.
Notes on porting to TensorFlow 2
A nice feature of TensorFlow 2 is that to write custom metrics (such as the concordance index) for TensorBoard, you no longer need to create a Summary protocol buffer manually. Instead, it suffices to call tf.summary.scalar and pass it a name, a float, and the current step.
So instead of
```python
import tensorflow as tf  # TensorFlow 1.x
from sksurv.metrics import concordance_index_censored
from tensorflow.core.framework import summary_pb2

# Compute the metric, then wrap it in a Summary protocol buffer by hand
c_index_metric = concordance_index_censored(…)[0]
writer = tf.summary.FileWriterCache.get(output_dir)
buf = summary_pb2.Summary(value=[summary_pb2.Summary.Value(
    tag="c-index", simple_value=c_index_metric)])
writer.add_summary(buf, global_step=global_step)
```
you can just do
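```python
import tensorflow as tf
from sksurv.metrics import concordance_index_censored

# Summaries are written to whichever writer is currently set as the default
with tf.summary.create_file_writer(output_dir).as_default():
    c_index_metric = concordance_index_censored(…)[0]
    tf.summary.scalar("c-index", c_index_metric, step=step)
```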
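For reference, the value logged above, concordance_index_censored(…)[0], is Harrell’s concordance index. The following pure-Python function is a simplified sketch of that metric, not the sksurv implementation: the name harrell_c_index is hypothetical, and it assumes there are no tied observation times (as far as I know, sksurv additionally handles tied times and compares risk scores with a small tolerance).

```python
from typing import Sequence


def harrell_c_index(event: Sequence[bool], time: Sequence[float],
                    risk: Sequence[float]) -> float:
    """Simplified Harrell's concordance index (hypothetical helper).

    Assumes no tied observation times. A pair (i, j) is comparable if
    sample i had an event and time[i] < time[j]. The pair is concordant
    if the model assigns the higher risk score to i; equal risk scores
    count as half a concordant pair.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored sample cannot be the earlier member of a pair
        for j in range(len(time)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


if __name__ == "__main__":
    event = [True, False, True, True]
    time = [1.0, 2.0, 3.0, 4.0]
    risk = [0.9, 0.5, 0.6, 0.1]  # higher risk should mean an earlier event
    print(harrell_c_index(event, time, risk))  # 1.0
```

Because every event is compared against every later sample, this sketch runs in O(n²) time, which is acceptable for a validation set of moderate size.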
Another feature that I liked is that you can now iterate over an instance of tf.data.Dataset and directly access the tensors and their values. This is much more convenient than having to call make_one_shot_iterator first to obtain an iterator, and then calling get_next() on that iterator to get the actual tensors.
Unfortunately, I also encountered some negatives when moving to TensorFlow 2.
First of all, there’s currently no officially supported way to produce a view of the executed graph that is identical to what you get with TensorFlow 1, unless you use the Keras training loop with the TensorBoard callback.
There’s tf.summary.trace_export, which, as described in this guide, sounds like it would produce the graph. However, with this approach you can only view individual operations in TensorBoard; you can’t inspect the sizes of an operation’s input and output tensors. After searching for a while, I eventually found the answer in a Stack Overflow post, and as it turns out, the approach it describes is exactly what the TensorBoard callback does.
Another thing I found odd is that if you define your custom loss as a subclass of tf.keras.losses.Loss, it insists that there are only two inputs, y_true and y_pred. In the case of Cox’s proportional hazards loss, the true label comprises an event indicator and an indicator matrix specifying which pairs in a batch are comparable. Luckily, the contents of y_true don’t get checked, so you can just pass a list, but I would prefer to write something like

```python
loss_fn(y_true_event=y_event, y_true_riskset=y_riskset, y_pred=pred_risk_score)
```

instead of

```python
loss_fn(y_true=[y_event, y_riskset], y_pred=pred_risk_score)
```
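To make the label structure concrete, here is a minimal NumPy sketch (not the tutorial’s TensorFlow implementation) that builds a risk set indicator matrix from observed times and evaluates the negative log partial likelihood with the keyword-style signature I would prefer. The names make_riskset, cox_ph_loss, and cox_ph_loss_from_list are hypothetical; averaging over the number of events is an assumed normalization, and ties are handled with Breslow’s approach.

```python
import numpy as np


def make_riskset(time: np.ndarray) -> np.ndarray:
    """Boolean matrix R with R[i, j] = True if sample j is still at risk
    at the observed time of sample i, i.e. time[j] >= time[i]."""
    time = np.asarray(time, dtype=float)
    return time[np.newaxis, :] >= time[:, np.newaxis]


def cox_ph_loss(y_true_event, y_true_riskset, y_pred) -> float:
    """Negative log partial likelihood of Cox's model (Breslow ties),
    averaged over the samples that experienced an event."""
    event = np.asarray(y_true_event, dtype=bool)
    riskset = np.asarray(y_true_riskset, dtype=bool)
    pred = np.asarray(y_pred, dtype=float).ravel()
    if not event.any():
        raise ValueError("the batch contains no events")
    # log(sum_{j in R_i} exp(pred_j)) for every row i, using the max trick
    # for numerical stability; entries outside the risk set become -inf.
    masked = np.where(riskset, pred[np.newaxis, :], -np.inf)
    # finite as long as R[i, i] is True, which make_riskset guarantees
    row_max = masked.max(axis=1, keepdims=True)
    log_risk = row_max[:, 0] + np.log(np.exp(masked - row_max).sum(axis=1))
    return float(-np.mean(pred[event] - log_risk[event]))


def cox_ph_loss_from_list(y_true, y_pred) -> float:
    """Adapter for a two-argument (y_true, y_pred) interface, where
    y_true packs [event indicator, risk set matrix] into a list."""
    y_event, y_riskset = y_true
    return cox_ph_loss(y_true_event=y_event, y_true_riskset=y_riskset,
                       y_pred=y_pred)
```

The adapter mirrors the list-packing workaround: it only unpacks y_true and delegates, so the keyword-style function remains the readable one.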
Finally, although eager execution is now enabled by default, the code runs significantly faster in graph mode, i.e., when you annotate your model’s call method with @tf.function. I guess you are only supposed to use eager execution for debugging purposes.