ml · 8 min read · 20 January 2023

Differentiation in TensorFlow: GradientTape and Custom Training Loops

How TensorFlow's automatic differentiation works under the hood, when to use GradientTape over Keras fit(), and how to build custom training loops for research and production models.

Automatic Differentiation

Neural network training requires computing gradients of a scalar loss $L$ with respect to every parameter $\theta$. For a network with millions of parameters, doing this analytically is infeasible — automatic differentiation (autodiff) computes exact gradients algorithmically by applying the chain rule through a recorded computation graph.

TensorFlow uses reverse-mode autodiff (also called backpropagation): first execute the forward pass, recording the operations on a computational tape, then replay the tape in reverse to compute gradients. This is efficient when there are many parameters and few outputs (the typical case for loss minimization).

GradientTape

tf.GradientTape is TensorFlow’s explicit API for recording computations and computing gradients:

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x ** 2  # Forward pass recorded on the tape

dy_dx = tape.gradient(y, x)  # dy/dx = 2x = 6.0

The with block defines the scope of recording. After the block, tape.gradient(target, sources) computes the gradient of target with respect to each source variable.
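The sources argument also accepts a list (or nested structure) of variables, and the gradients come back in the same structure. A quick sketch:

```python
import tensorflow as tf

w = tf.Variable(2.0)
b = tf.Variable(1.0)

with tf.GradientTape() as tape:
    y = w * 3.0 + b  # y = 3w + b

# Passing a list of sources returns a list of gradients in the same order
dw, db = tape.gradient(y, [w, b])
print(dw.numpy(), db.numpy())  # dy/dw = 3.0, dy/db = 1.0
```

This is the same mechanism the training loops below rely on when passing model.trainable_variables as the sources.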

Watching Non-Variable Tensors

By default, GradientTape automatically watches tf.Variable objects but not plain tensors. To differentiate with respect to a tensor (e.g., an input, not a learned weight):

x = tf.constant(3.0)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = x ** 2

dy_dx = tape.gradient(y, x)

This is useful for computing input gradients (e.g., for gradient-based input attribution methods like Integrated Gradients).
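As a minimal sketch of the idea (the tiny model here is hypothetical; any Keras model works the same way), the gradient of the top class score with respect to a watched input gives a simple saliency map:

```python
import tensorflow as tf

# Hypothetical tiny classifier standing in for a real model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(3),
])

x = tf.random.normal([1, 4])  # A single input example (a plain tensor, not a Variable)

with tf.GradientTape() as tape:
    tape.watch(x)                               # Plain tensors must be watched explicitly
    logits = model(x)
    top_score = tf.reduce_max(logits, axis=-1)  # Score of the most likely class

saliency = tape.gradient(top_score, x)  # d(top score)/d(input), same shape as x
```

Methods like Integrated Gradients build on exactly this input-gradient computation, averaging it along a path of interpolated inputs.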

Higher-Order Derivatives

Nest GradientTape contexts to compute second-order derivatives:

x = tf.Variable(3.0)

with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        y = x ** 3
    dy_dx = inner.gradient(y, x)      # First derivative: 3x²

d2y_dx2 = outer.gradient(dy_dx, x)   # Second derivative: 6x

Nested tapes are used in meta-learning (MAML), physics-informed networks, and any application requiring gradient-of-gradient computations.

Custom Training Loops

model.fit() is convenient but opaque. Custom training loops give full control over every step — necessary for research experiments, multi-loss objectives, gradient clipping strategies, or non-standard update rules.

The Standard Pattern

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

@tf.function  # Compile to a graph for speed
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)
        loss = loss_fn(y_batch, predictions)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

for epoch in range(num_epochs):
    for x_batch, y_batch in dataset:
        loss = train_step(x_batch, y_batch)

The @tf.function decorator traces the function once and compiles it to a TensorFlow graph, removing Python overhead from every subsequent call. For complex training loops this can yield speedups on the order of 2–5× over eager execution, depending on the model and hardware.
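To see the tracing behavior concretely, here is a small sketch: the Python body runs only while tracing, so a counter reveals when a retrace happens (the counter and function name are illustrative):

```python
import tensorflow as tf

trace_count = 0

@tf.function
def square(x):
    global trace_count
    trace_count += 1  # Python side effects run only during tracing
    return x * x

square(tf.constant(2.0))
square(tf.constant(3.0))  # Same dtype and shape: reuses the compiled graph
print(trace_count)        # 1 — traced once for float32 scalars

square(tf.constant(2))    # New dtype (int32): triggers a retrace
print(trace_count)        # 2
```

This is also why passing Python scalars (rather than tensors) to a @tf.function can cause a retrace per distinct value.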

Gradient Manipulation

The gradients list from tape.gradient() can be modified before apply_gradients:

Gradient clipping (prevents exploding gradients in RNNs and deep networks):

gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=1.0)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Gradient accumulation (simulates larger batch sizes with limited memory):

accumulated_gradients = [tf.Variable(tf.zeros_like(v)) 
                         for v in model.trainable_variables]

for micro_batch_x, micro_batch_y in micro_batches:
    with tf.GradientTape() as tape:
        predictions = model(micro_batch_x, training=True)
        # Divide so the accumulated total matches the full-batch gradient
        loss = loss_fn(micro_batch_y, predictions) / num_micro_batches
    grads = tape.gradient(loss, model.trainable_variables)
    for acc_grad, grad in zip(accumulated_gradients, grads):
        acc_grad.assign_add(grad)

optimizer.apply_gradients(zip(accumulated_gradients, model.trainable_variables))
for acc_grad in accumulated_gradients:
    acc_grad.assign(tf.zeros_like(acc_grad))  # Reset before the next accumulation cycle

Custom per-layer learning rates:

# Different learning rates for different layer groups
optimizer_backbone = tf.keras.optimizers.Adam(1e-5)
optimizer_head = tf.keras.optimizers.Adam(1e-3)

backbone_vars = model.backbone.trainable_variables
head_vars = model.head.trainable_variables

# persistent=True is required: gradient() is called twice on this tape
with tf.GradientTape(persistent=True) as tape:
    loss = compute_loss(model(x), y)

backbone_grads = tape.gradient(loss, backbone_vars)
head_grads = tape.gradient(loss, head_vars)
del tape  # Release the recorded forward pass

optimizer_backbone.apply_gradients(zip(backbone_grads, backbone_vars))
optimizer_head.apply_gradients(zip(head_grads, head_vars))

When to Use GradientTape vs. model.fit()

Use Case | Recommendation
Standard supervised training | model.fit() — simpler, handles callbacks, logging, distributed training
Multiple losses with different weights | Custom loop — control each loss contribution independently
Auxiliary tasks or multi-output losses | Custom loop — aggregate losses before calling tape.gradient
Gradient penalty (e.g., WGAN-GP) | Custom loop — requires second-order gradients
Meta-learning or MAML | Nested tapes in a custom loop
Research: non-standard optimizers | Custom loop — implement the update rule directly
Input attribution / saliency maps | tape.watch on inputs; no training loop needed

Persistent Tapes

By default, GradientTape releases resources after the first gradient() call. If you need multiple gradient computations from the same tape (e.g., gradients of the same loss with respect to different sets of variables):

with tf.GradientTape(persistent=True) as tape:
    loss = compute_loss(model(x), y)

grad_a = tape.gradient(loss, model.layer_a.trainable_variables)
grad_b = tape.gradient(loss, model.layer_b.trainable_variables)

del tape  # Manually release resources

Persistent tapes hold the recorded forward pass in memory until explicitly deleted — use them only when necessary.
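One common use is taking gradients of several different targets from a single recorded forward pass. A minimal example:

```python
import tensorflow as tf

x = tf.Variable(2.0)

with tf.GradientTape(persistent=True) as tape:
    y = x ** 2
    z = x ** 3

dy_dx = tape.gradient(y, x)  # 2x = 4.0
dz_dx = tape.gradient(z, x)  # 3x² = 12.0
del tape  # Release the recorded forward pass
```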

Practical Notes

Debugging gradients: None gradients from tape.gradient usually mean the variable is not connected to the computation graph (e.g., it was created outside the tape scope, or a non-differentiable operation broke the graph). Use tape.gradient(..., unconnected_gradients=tf.UnconnectedGradients.ZERO) to get zeros instead of None for disconnected variables.
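A quick sketch of the difference (using a persistent tape so gradient() can be called twice):

```python
import tensorflow as tf

a = tf.Variable(1.0)
b = tf.Variable(2.0)  # Never used in the forward pass

with tf.GradientTape(persistent=True) as tape:
    y = a * 3.0

g_none = tape.gradient(y, b)  # None: b is disconnected from y
g_zero = tape.gradient(y, b,
                       unconnected_gradients=tf.UnconnectedGradients.ZERO)
del tape
print(g_none, float(g_zero))  # None 0.0
```

Zeros are convenient when the gradients feed directly into apply_gradients, but None is the better default while debugging, since it makes disconnections visible.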

Performance: eager execution (default in TF2) is convenient for debugging but slow for training. Always wrap the training step in @tf.function for production runs. The first call triggers tracing (slow); all subsequent calls execute the compiled graph (fast).

Mixed precision: combine with tf.keras.mixed_precision to run forward/backward passes in FP16 while keeping the optimizer state in FP32. The loss scaler (LossScaleOptimizer) handles the FP16 gradient underflow problem automatically.
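As a rough sketch of how this combines with a custom step (the model and loss here are placeholders; the pattern follows the TF 2.x LossScaleOptimizer API):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Hypothetical model; keep the final layer in float32 for numerical stability
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, dtype="float32"),
])

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(1e-3))

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
        scaled_loss = optimizer.get_scaled_loss(loss)  # Scale up to avoid FP16 underflow
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = optimizer.get_unscaled_gradients(scaled_grads)  # Scale back down
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```

The explicit get_scaled_loss / get_unscaled_gradients pair is only needed in custom loops; model.fit() applies loss scaling automatically when the wrapped optimizer is used.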

tensorflow automatic-differentiation deep-learning custom-training ml