Survival Analysis for Deep Learning Tutorial for TensorFlow 2
A while back, I posted the Survival Analysis for Deep Learning tutorial. That tutorial was written for TensorFlow 1 using the tf.estimator API. The changes between version 1 and the current TensorFlow 2 are significant enough that the code no longer runs with a recent TensorFlow release. Therefore, I created a new version of the tutorial that is compatible with TensorFlow 2. The text is essentially identical, but the training and evaluation procedure has changed.
Notes on porting to TensorFlow 2
A nice feature of TensorFlow 2 is that in order to write custom metrics (such as concordance index) for TensorBoard, you don't need to create a Summary protocol buffer manually; instead, it suffices to call tf.summary.scalar and pass it a name and a float.
So instead of
```python
from sksurv.metrics import concordance_index_censored
from tensorflow.core.framework import summary_pb2

c_index_metric = concordance_index_censored(…)
writer = tf.summary.FileWriterCache.get(output_dir)
buf = summary_pb2.Summary(value=[summary_pb2.Summary.Value(
    tag="c-index", simple_value=c_index_metric)])
writer.add_summary(buf, global_step=global_step)
```
you can just do
```python
from sksurv.metrics import concordance_index_censored

with tf.summary.create_file_writer(output_dir).as_default():
    c_index_metric = concordance_index_censored(…)
    tf.summary.scalar("c-index", c_index_metric, step=step)
```
Another feature that I liked is that you can now iterate over an instance of tf.data.Dataset directly and access the tensors and their values. This is much more convenient than having to call make_one_shot_iterator() first, which gives you an iterator, on which you then call get_next() to obtain actual tensors.
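As a minimal sketch of the difference (the toy data and batch size are made up purely for illustration):

```python
import numpy as np
import tensorflow as tf

# Toy data, purely for illustration.
features = np.arange(6, dtype=np.float32).reshape(3, 2)
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)

# TensorFlow 2: a Dataset is a plain Python iterable yielding eager tensors.
for batch in ds:
    print(batch.shape)

# TensorFlow 1 required the more verbose:
#   it = ds.make_one_shot_iterator()
#   next_batch = it.get_next()
#   with tf.Session() as sess:
#       values = sess.run(next_batch)
```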
Unfortunately, I also encountered some negatives when moving to TensorFlow 2.
First of all, there is currently no officially supported way to produce a view of the executed graph that is identical to what you get with TensorFlow 1, unless you use the Keras training loop. There is tf.summary.trace_export, which, as described in this guide, sounds like it would produce the graph; however, with this approach you can only view individual operations in TensorBoard, but you cannot inspect the sizes of an operation's input and output tensors. After searching for a while, I eventually found the answer in a Stack Overflow post, and, as it turns out, that is exactly what the Keras TensorBoard callback does.
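For reference, the trace-based approach from the guide looks roughly like this (the log directory and function names here are my own stand-ins); it records the executed operations, but, as noted above, the resulting graph view lacks tensor shapes:

```python
import tensorflow as tf

@tf.function
def forward(x):
    # stand-in for a model's call method
    return tf.nn.relu(x)

writer = tf.summary.create_file_writer("logs/trace")  # hypothetical directory
tf.summary.trace_on(graph=True)
forward(tf.ones((2, 2)))  # run once so the function gets traced
with writer.as_default():
    tf.summary.trace_export(name="forward_graph", step=0)
```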
Another thing I found odd is that if you define your custom loss as a subclass of tf.keras.losses.Loss, it insists that there are only two inputs: y_true and y_pred. In the case of Cox's proportional hazards loss, the true label comprises an event indicator and an indicator matrix specifying which pairs in a batch are comparable.
Luckily, the contents of y_true don't get checked, so you can just pass a list, but I would prefer to write something like

loss_fn(y_true_event=y_event, y_true_riskset=y_riskset, y_pred=pred_risk_score)

instead of

loss_fn(y_true=[y_event, y_riskset], y_pred=pred_risk_score)
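A minimal sketch of such a loss, passing the event indicator and risk-set matrix as a list in y_true (this is my own illustration, not the tutorial's exact implementation; the class name, shapes, and masking convention are assumptions):

```python
import tensorflow as tf

class CoxPHLoss(tf.keras.losses.Loss):
    """Sketch of Cox's negative partial log-likelihood.

    y_true is a list [event, riskset] where
      event   has shape (batch, 1); 1.0 marks an observed event,
      riskset has shape (batch, batch); riskset[i, j] = 1.0 if sample j
              is comparable to sample i (still at risk at i's event time).
    y_pred holds predicted risk scores with shape (batch, 1).
    """

    def call(self, y_true, y_pred):
        event, riskset = y_true
        pred = tf.squeeze(y_pred, axis=-1)                          # (batch,)
        pred_matrix = tf.broadcast_to(pred[tf.newaxis, :],
                                      tf.shape(riskset))            # (batch, batch)
        # Mask out non-comparable pairs before the log-sum-exp.
        neg_inf = tf.fill(tf.shape(pred_matrix),
                          tf.constant(-1e20, pred.dtype))
        masked = tf.where(riskset > 0.0, pred_matrix, neg_inf)
        log_risk = tf.reduce_logsumexp(masked, axis=1)              # (batch,)
        # Only samples with an observed event contribute to the loss.
        return tf.squeeze(event, axis=-1) * (log_risk - pred)

loss_fn = CoxPHLoss()
event = tf.constant([[1.0], [0.0]])
riskset = tf.constant([[1.0, 1.0],
                       [0.0, 1.0]])
pred_risk_score = tf.constant([[0.0], [0.0]])
loss = loss_fn(y_true=[event, riskset], y_pred=pred_risk_score)
```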
Finally, although eager execution is now enabled by default, the code runs significantly faster in graph mode, i.e., when annotating your model's call method with @tf.function. I guess you are only supposed to use eager execution for debugging purposes.
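As a sketch (the model and layer sizes are arbitrary):

```python
import tensorflow as tf

class RiskModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    @tf.function  # traced into a graph on first call; drop this line to debug eagerly
    def call(self, x):
        return self.dense(x)

model = RiskModel()
scores = model(tf.ones((4, 3)))  # first call triggers tracing
```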