ml · 12 min read · 1 February 2024

Neural Network Training Playbook

A practitioner's guide to training neural networks — from initialization and optimization to regularization, debugging, and the decisions that actually determine whether your model converges.

Training a neural network is deterministic in principle and chaotic in practice. The same architecture, same optimizer, and same data can produce wildly different results depending on initialization, learning rate schedule, batch size, and a dozen other choices. This playbook is the consolidated set of decisions that actually matter.

Before Training: Set Up the Experiment Correctly

Fix randomness. Set seeds for numpy, torch, and Python’s random at the start of every experiment. Unreproducible results make debugging impossible.
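A minimal seeding helper along these lines (the function name and default seed are illustrative):

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines


set_seed(0)
```

Call it once at the top of every experiment script, before any data loading or model construction.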

Verify your data pipeline first. Load a batch, print shapes, visualize samples. Most “the model isn’t learning” bugs are data pipeline bugs, not model bugs. Check that your labels are aligned with your inputs.

Establish a baseline. Before your full model, train the simplest possible thing that could work: logistic regression, a single linear layer, a 2-layer MLP. If the simple thing doesn’t work, the complex thing definitely won’t.
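A baseline can be this small. The sketch below trains a single linear layer ("logistic regression") on stand-in data; the feature and class counts, and the synthetic target, are hypothetical placeholders for your own dataset:

```python
import torch
from torch import nn

torch.manual_seed(0)
baseline = nn.Linear(20, 3)  # 20 features, 3 classes: stand-in shapes
opt = torch.optim.Adam(baseline.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(256, 20)
y = (x[:, 0] > 0).long() + (x[:, 1] > 0).long()  # a learnable synthetic target

start_loss = loss_fn(baseline(x), y).item()
for _ in range(200):
    opt.zero_grad()
    loss = loss_fn(baseline(x), y)
    loss.backward()
    opt.step()
```

If even this loss does not move, look at the data pipeline before touching the architecture.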

Initialization

Poor initialization can mean the network spends thousands of steps recovering from a bad starting point, or never recovers at all.

Xavier / Glorot initialization: Designed for tanh / sigmoid activations. Weights sampled from a uniform or normal distribution with variance 2/(fan_in + fan_out), i.e. scaled by sqrt(2/(fan_in + fan_out)).

He / Kaiming initialization: Designed for ReLU activations. Weights sampled from a normal distribution scaled by sqrt(2/fan_in). The factor of 2 compensates for ReLU zeroing out half of the activations.

Practical rule: Use Kaiming init for ReLU-based networks. Note that PyTorch's default for Conv and Linear layers is a Kaiming-uniform variant with a=sqrt(5), not the ReLU-gain version; verify what you're getting by checking the layer's reset_parameters code.

Zero initialization trap: Biases can be initialized to zero. Weights should not — zero weights make all neurons identical, and gradient updates will never break the symmetry.
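Both rules can be applied with a small init function like this sketch (the network shapes are illustrative):

```python
import torch
from torch import nn


def init_weights(m: nn.Module) -> None:
    """Kaiming init for ReLU networks; zero biases are safe, zero weights are not."""
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)


model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_weights)
```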

The Learning Rate Is the Most Important Hyperparameter

Everything else is secondary to getting the learning rate right.

Too high: The loss bounces around or diverges. You’ll see NaN losses or oscillating validation curves.

Too low: Training converges, but slowly — and may settle in a poor local minimum.

The right range: Start with 1e-3 for Adam, 1e-1 for SGD. These are rough starting points; use a learning rate finder to verify.

Learning rate finder: Train for one epoch, increasing the learning rate exponentially from a very small value (1e-7) to a large value (10). Plot loss vs. learning rate. The optimal learning rate is just before the loss starts increasing sharply.
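A bare-bones version of the range test, assuming a toy model and random stand-in data (plotting is left out; in practice you would plot `history` on a log-x axis):

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-7)
loss_fn = nn.MSELoss()

lr, lr_max, steps = 1e-7, 10.0, 100
gamma = (lr_max / lr) ** (1 / steps)  # multiplicative step of the exponential sweep
history = []

for _ in range(steps):
    x, y = torch.randn(32, 10), torch.randn(32, 1)  # stand-in batch
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
    history.append((lr, loss.item()))  # plot this: loss vs. learning rate
    lr *= gamma
    for g in opt.param_groups:
        g["lr"] = lr
```

Near the top of the sweep the loss will blow up; that is expected and is exactly the signal the plot is meant to show.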

Learning rate schedules:

Step decay: drop the learning rate by a fixed factor (e.g. 10x) at preset epochs. Simple, and still common with SGD on image tasks.

Cosine annealing: smoothly decay the learning rate toward zero over the full run. A strong default that removes the "when to drop" decision.

Warmup: linearly ramp the learning rate up from near zero over the first few hundred to few thousand steps before handing off to the main schedule. Especially important for Transformers and large batch sizes.
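One common recipe is linear warmup followed by cosine annealing, composed with SequentialLR. The step counts below are illustrative, not recommendations:

```python
import torch
from torch import nn

model = nn.Linear(10, 10)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 100 warmup steps ramping from 1% of the base LR, then 900 cosine steps.
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, total_iters=100)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=900)
sched = torch.optim.lr_scheduler.SequentialLR(
    opt, schedulers=[warmup, cosine], milestones=[100]
)

for step in range(1000):
    opt.step()    # in a real loop this follows loss.backward()
    sched.step()  # one scheduler step per optimizer step
```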

Batch Size and Its Effects

Larger batch sizes: lower-variance gradient estimates, better hardware utilization, and faster wall-clock epochs, but they tend to generalize slightly worse unless the learning rate and warmup are adjusted.

Smaller batch sizes: noisier gradients that act as a mild implicit regularizer, a lower memory footprint, and often better generalization, at the cost of slower throughput.

Rule of thumb: Start with batch size 32 or 64. If you scale to larger batches, scale the learning rate proportionally (linear scaling rule: 2x batch size → 2x learning rate) and use a longer warmup.

For Transformers: batch size is often constrained by memory. Use gradient accumulation to simulate larger effective batch sizes.
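A gradient-accumulation sketch, assuming stand-in data: four micro-batches of 8 behave like one batch of 32, because gradients add up in `.grad` until the optimizer step.

```python
import torch
from torch import nn

model = nn.Linear(16, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
accum_steps = 4  # effective batch = micro-batch size * accum_steps

opt.zero_grad()
for step in range(100):
    x, y = torch.randn(8, 16), torch.randn(8, 1)  # micro-batch
    loss = loss_fn(model(x), y) / accum_steps  # average across micro-batches
    loss.backward()                            # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        opt.step()
        opt.zero_grad()
```

Dividing the loss by `accum_steps` keeps the gradient magnitude the same as a single large batch would produce.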

Optimizers

Adam: The default choice. Adaptive learning rates per parameter. Works well without extensive tuning. Use lr=1e-3, betas=(0.9, 0.999), eps=1e-8.

AdamW: Adam with decoupled (correct) weight decay. Almost always prefer it over plain Adam when using weight decay; standard Adam folds the decay into the adaptive gradient update, which is subtly wrong.

SGD with momentum: Often better than Adam for final model performance on image tasks, but requires more tuning. Requires a learning rate schedule.

Lion: A newer optimizer that can match or beat AdamW with less memory overhead. Worth trying if memory is a constraint.

Practical guidance: Use AdamW for most tasks. If you need the absolute best performance and have time to tune, compare against SGD with cosine annealing.
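Constructor-level versions of the two workhorse setups, using the rough starting values from the text (not tuned numbers):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# Default choice: AdamW with decoupled weight decay.
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3,
                          betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)

# Tuning-heavy alternative for image tasks; pair with a schedule.
sgd = torch.optim.SGD(model.parameters(), lr=1e-1,
                      momentum=0.9, weight_decay=1e-4, nesterov=True)
```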

Regularization

L2 weight decay: Penalizes large weight values. Implemented correctly in AdamW. Typical values: 1e-4 to 1e-2.

Dropout: Randomly zero out activations during training. Forces the network to learn redundant representations. Use after linear layers and attention layers (not usually after convolutions). Rate: 0.1–0.5 depending on network depth and data size.

Batch normalization: Normalizes layer activations to have zero mean and unit variance within a batch. Reduces internal covariate shift, allows higher learning rates, provides implicit regularization. Apply before the activation function.
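The "before the activation" ordering looks like this in a conv block (channel counts are illustrative):

```python
import torch
from torch import nn

# Conv -> BatchNorm -> ReLU: normalize the pre-activations, then apply the nonlinearity.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
```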

Layer normalization: Like batch norm but normalizes across the feature dimension rather than the batch dimension. Required for Transformers (batch norm breaks with variable-length sequences).

Data augmentation: Extends the effective training set size by applying random transformations to inputs. Domain-specific — for images: random crops, horizontal flips, color jitter. For time-series: time warping, magnitude scaling, noise injection.

Early stopping: Stop training when validation loss stops improving (with a patience parameter). Simple and effective. Keep the checkpoint from the epoch with best validation loss, not the last epoch.
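The bookkeeping fits in a few lines. A minimal sketch (the class name and patience value are illustrative; checkpoint saving is left as a comment):

```python
class EarlyStopping:
    """Track validation loss and signal when to stop after `patience` bad epochs."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss  # save the model checkpoint here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Because the checkpoint is saved only on improvement, restoring it gives you the best-validation epoch rather than the last one.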

Debugging a Training Run That Isn’t Working

The debugging process in order:

  1. Is the data correct? Print a few batches. Visualize them. Check label distribution. Verify that preprocessing is deterministic.

  2. Does the model overfit a small batch? Take 10 training examples and train until the loss goes to near zero. If it can’t overfit 10 examples, the model is broken (architecture bug, gradient bug, etc.).
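The overfit check as a self-contained sketch, with random stand-in data in place of 10 real training examples:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(10, 8)           # stand-in for 10 real training examples
y = torch.randint(0, 2, (10,))

for _ in range(500):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()

# A working model memorizes 10 examples easily; if this fails, the model is broken.
assert loss.item() < 0.05
```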

  3. Are gradients flowing? Check grad_fn on your loss tensor. Print gradient norms after .backward(). If gradients are zero or NaN, you have a vanishing/exploding gradient problem.
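Inspecting gradient norms takes only a few lines. A sketch with a toy model and a stand-in loss:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss = model(torch.randn(16, 4)).pow(2).mean()  # stand-in loss; has a grad_fn
loss.backward()

# Per-parameter gradient norms: zeros or NaNs here point at the broken layer.
for name, p in model.named_parameters():
    norm = p.grad.norm().item()
    print(f"{name}: grad norm = {norm:.4f}")
    assert norm == norm, f"NaN gradient in {name}"  # NaN != NaN
```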

  4. Is the loss decreasing at all? If the loss is completely flat, the learning rate may be too low or too high, or there’s a bug in the optimizer setup.

  5. Is there a train/val gap? Large train/val gap → overfitting → more regularization. No train/val gap but high loss on both → underfitting → larger model, more data, less regularization.

Mixed Precision Training

For large models: use torch.cuda.amp.autocast() for automatic mixed precision (AMP). This runs most of the forward and backward pass in float16 while keeping the master weights, and the optimizer step, in float32. Typically up to 2x speedup and roughly 2x activation-memory reduction.

Gotcha: some operations are numerically unstable in float16. The gradient scaler (torch.cuda.amp.GradScaler) handles this by scaling the loss before backprop and unscaling before the optimizer step.

The Cycle

Training is iterative. The discipline:

  1. Train baseline
  2. Identify the failure mode (is it underfitting? overfitting? wrong examples? class imbalance?)
  3. Fix the most impactful issue
  4. Measure the delta
  5. Repeat

Don’t change multiple things at once. One variable at a time, so you know what caused the improvement.

The goal is not a perfect model — it’s a model that improves reliably in the direction of your objective, with an honest evaluation on held-out data.

deep-learning neural-networks training optimization pytorch
