Training a neural network is deterministic in principle and chaotic in practice. The same architecture, same optimizer, and same data can produce wildly different results depending on initialization, learning rate schedule, batch size, and a dozen other choices. This playbook is the consolidated set of decisions that actually matter.
Before Training: Set Up the Experiment Correctly
Fix randomness. Set seeds for numpy, torch, and Python’s random at the start of every experiment. Unreproducible results make debugging impossible.
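A minimal seeding helper along these lines (the function name and the choice to also seed CUDA are assumptions; adapt to your setup):

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch (CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op if CUDA is unavailable
    # For full determinism you may also need torch.backends.cudnn.deterministic = True (slower).
```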
Verify your data pipeline first. Load a batch, print shapes, visualize samples. Most “the model isn’t learning” bugs are data pipeline bugs, not model bugs. Check that your labels are aligned with your inputs.
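A quick sanity check of that kind, assuming an image-classification DataLoader named train_loader (the variable names are placeholders):

```python
import matplotlib.pyplot as plt

# Pull one batch and inspect it before training anything.
images, labels = next(iter(train_loader))
print(images.shape, images.dtype, images.min().item(), images.max().item())
print(labels.shape, labels[:16])  # eyeball that labels look plausible

# Spot-check alignment: visualize a few samples with their labels.
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(images[i].permute(1, 2, 0).squeeze(), cmap="gray")
    plt.title(int(labels[i]))
    plt.axis("off")
plt.show()
```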
Establish a baseline. Before your full model, train the simplest possible thing that could work: logistic regression, a single linear layer, a 2-layer MLP. If the simple thing doesn’t work, the complex thing definitely won’t.
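A baseline of that shape, with placeholder sizes (e.g. MNIST-like data):

```python
import torch.nn as nn

input_dim, num_classes = 28 * 28, 10  # placeholder sizes

# Simplest thing that could work: a linear classifier (logistic regression).
logistic_regression = nn.Sequential(nn.Flatten(), nn.Linear(input_dim, num_classes))

# One step up: a 2-layer MLP.
two_layer_mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(input_dim, 256),
    nn.ReLU(),
    nn.Linear(256, num_classes),
)
```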
Initialization
Poor initialization can mean the network spends thousands of steps recovering from a bad starting point, or never recovers at all.
Xavier / Glorot initialization: Designed for tanh / sigmoid activations. Weights sampled from a uniform or normal distribution scaled by 1/sqrt(fan_in).
He / Kaiming initialization: Designed for ReLU activations. Weights sampled from a normal distribution scaled by sqrt(2/fan_in). The sqrt(2) factor compensates for ReLU zeroing half of activations.
Practical rule: Use Kaiming init for ReLU-based networks. PyTorch's default for Conv and Linear layers is a Kaiming-uniform variant (with a=sqrt(5), a leaky-ReLU gain), not the standard He init for ReLU, so verify what you're getting by checking the layer initialization code or set it explicitly (see the sketch below).
Zero initialization trap: Biases can be initialized to zero. Weights should not — zero weights make all neurons identical, and gradient updates will never break the symmetry.
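A sketch of applying He init explicitly (and zeroing biases) rather than relying on layer defaults; the apply-based pattern is one common way to do it, not the only one:

```python
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """He/Kaiming init for ReLU networks; biases start at zero."""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_weights)  # runs init_weights on every submodule
```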
The Learning Rate Is the Most Important Hyperparameter
Everything else is secondary to getting the learning rate right.
Too high: The loss bounces around or diverges. You’ll see NaN losses or oscillating validation curves.
Too low: Training converges, but slowly — and may settle in a poor local minimum.
The right range: Start with 1e-3 for Adam, 1e-1 for SGD. These are rough starting points; use a learning rate finder to verify.
Learning rate finder: Train for one epoch, increasing the learning rate exponentially from a very small value (1e-7) to a large value (10). Plot loss vs. learning rate. The optimal learning rate is just before the loss starts increasing sharply.
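A bare-bones version of the idea (a sketch only; libraries such as fastai ship more robust finders). The model, data loader, and loss function are assumed to exist and are passed in:

```python
import math

import torch


def lr_finder(model, train_loader, criterion, lr_min=1e-7, lr_max=10.0, num_steps=100):
    """Sweep the LR exponentially over one pass and record the loss at each step."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_min)
    gamma = (lr_max / lr_min) ** (1 / num_steps)  # multiplicative LR step
    lrs, losses = [], []
    for step, (x, y) in enumerate(train_loader):
        if step >= num_steps:
            break
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        losses.append(loss.item())
        for group in optimizer.param_groups:
            group["lr"] *= gamma  # exponential increase
        if not math.isfinite(losses[-1]):
            break  # loss has diverged; stop early
    return lrs, losses  # plot loss vs. LR and pick a value just before the blow-up
```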
Learning rate schedules:
- Constant: Simple but almost always suboptimal
- Step decay: Reduce by a factor (0.1) every N epochs
- Cosine annealing: Smooth reduction following a cosine curve — good for final training runs
- One-cycle policy: Ramp up then ramp down, with a very high maximum LR — often the fastest path to a good model
- Warmup + decay: Start low, warm up over the first few epochs (especially important for Transformers), then decay
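As a concrete instance of the last pattern, linear warmup followed by cosine decay can be written as a single LambdaLR (a sketch; the model and step counts are placeholders):

```python
import math

import torch

warmup_steps, total_steps = 1_000, 100_000  # placeholder values


def lr_lambda(step: int) -> float:
    """Linear warmup to the base LR, then cosine decay toward zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


model = torch.nn.Linear(10, 10)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Call scheduler.step() once per optimizer step (not per epoch) for this schedule.
```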
Batch Size and Its Effects
Larger batch sizes:
- Faster per-epoch training (better GPU utilization)
- Better gradient estimates (less noisy)
- But: often worse generalization (sharper minima, worse on new data)
Smaller batch sizes:
- Noisier gradients (but this noise can help escape sharp local minima)
- Better generalization on many tasks
- Slower per-epoch (more gradient updates, but each is less efficient)
Rule of thumb: Start with batch size 32 or 64. If you scale to larger batches, scale the learning rate proportionally (linear scaling rule: 2x batch size → 2x learning rate) and use a longer warmup.
For Transformers: batch size is often constrained by memory. Use gradient accumulation to simulate larger effective batch sizes.
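A gradient-accumulation sketch; the model, optimizer, loss function, and data loader are assumed to exist and are passed in, and accum_steps is a placeholder:

```python
def train_with_accumulation(model, train_loader, optimizer, criterion, accum_steps=8):
    """One epoch with gradient accumulation: effective batch = loader batch size * accum_steps."""
    optimizer.zero_grad()
    for step, (x, y) in enumerate(train_loader):
        loss = criterion(model(x), y)
        (loss / accum_steps).backward()  # scale so the accumulated gradients average over the big batch
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```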
Optimizers
Adam: The default choice. Adaptive learning rates per parameter. Works well without extensive tuning. Use lr=1e-3, betas=(0.9, 0.999), eps=1e-8.
AdamW: Adam with decoupled (proper) weight decay. Almost always prefer it over Adam when using L2 regularization; standard Adam's weight-decay implementation is subtly wrong because it couples the decay with the adaptive learning rate.
SGD with momentum: Often better than Adam for final model performance on image tasks, but requires more tuning. Requires a learning rate schedule.
Lion: A newer optimizer that can match or beat AdamW with less memory overhead. Worth trying if memory is a constraint.
Practical guidance: Use AdamW for most tasks. If you need the absolute best performance and have time to tune, compare against SGD with cosine annealing.
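The two setups compared above, as they are typically configured (the values are the rough starting points from this section, not tuned numbers, and the model is a stand-in):

```python
import torch

model = torch.nn.Linear(10, 10)  # stand-in model

# Default choice: AdamW with decoupled weight decay.
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3,
                          betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)

# Tuned alternative for image tasks: SGD with momentum plus cosine annealing.
sgd = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=100)  # T_max = number of epochs
```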
Regularization
L2 weight decay: Penalizes large weight values. Implemented correctly in AdamW. Typical values: 1e-4 to 1e-2.
Dropout: Randomly zero out activations during training. Forces the network to learn redundant representations. Use after linear layers and attention layers (not usually after convolutions). Rate: 0.1–0.5 depending on network depth and data size.
Batch normalization: Normalizes layer activations to have zero mean and unit variance within a batch. Reduces internal covariate shift, allows higher learning rates, provides implicit regularization. Apply before the activation function.
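A sketch of the conventional placements (conv, then batch norm, then ReLU; dropout after the fully connected layer, not after convolutions); the sizes assume 32x32 RGB inputs and are placeholders:

```python
import torch.nn as nn

cnn_block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),   # normalize before the activation
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 32 * 32, 256),
    nn.ReLU(),
    nn.Dropout(p=0.3),    # dropout after the linear layer
    nn.Linear(256, 10),
)
```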
Layer normalization: Like batch norm but normalizes across the feature dimension rather than the batch dimension. Required for Transformers (batch norm breaks with variable-length sequences).
Data augmentation: Extends the effective training set size by applying random transformations to inputs. Domain-specific — for images: random crops, horizontal flips, color jitter. For time-series: time warping, magnitude scaling, noise injection.
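For the image case, a typical torchvision pipeline (the specific transforms and magnitudes are illustrative choices, not a recommendation for any particular dataset):

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),        # random crops
    transforms.RandomHorizontalFlip(),           # horizontal flips
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # color jitter
    transforms.ToTensor(),
])
# Apply augmentation only to the training set; the validation set gets ToTensor() alone.
```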
Early stopping: Stop training when validation loss stops improving (with a patience parameter). Simple and effective. Keep the checkpoint from the epoch with best validation loss, not the last epoch.
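A minimal early-stopping helper of that shape (a sketch; this class is not a library API), to be called once per epoch with the validation loss:

```python
import copy


class EarlyStopping:
    """Stop when validation loss hasn't improved for `patience` epochs; keep the best weights."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss: float, model) -> bool:
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())  # best checkpoint, not the last one
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True -> stop training
```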
Debugging a Training Run That Isn’t Working
The debugging process in order:
- Is the data correct? Print a few batches. Visualize them. Check label distribution. Verify that preprocessing is deterministic.
- Does the model overfit a small batch? Take 10 training examples and train until the loss goes to near zero. If it can’t overfit 10 examples, the model is broken (architecture bug, gradient bug, etc.).
- Are gradients flowing? Check grad_fn on your loss tensor. Print gradient norms after .backward(). If gradients are zero or NaN, you have a vanishing/exploding gradient problem.
- Is the loss decreasing at all? If the loss is completely flat, the learning rate may be too low or too high, or there’s a bug in the optimizer setup.
- Is there a train/val gap? Large train/val gap → overfitting → more regularization. No train/val gap but high loss on both → underfitting → larger model, more data, less regularization.
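A sketch combining the small-batch overfit check and the gradient-norm check above; the model, loss function, and a ten-example batch (x_small, y_small) are assumed to exist:

```python
import torch


def overfit_small_batch(model, criterion, x_small, y_small, steps=500, lr=1e-3):
    """Train on ~10 examples; the loss should approach zero if the model and gradients are sound."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for step in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x_small), y_small)
        loss.backward()
        # Gradient check: zero or NaN norms point at a vanishing/exploding-gradient bug.
        total_norm = torch.norm(
            torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None])
        )
        optimizer.step()
        if step % 100 == 0:
            print(f"step {step}: loss={loss.item():.4f} grad_norm={total_norm.item():.4f}")
```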
Mixed Precision Training
For large models: use torch.cuda.amp.autocast() for automatic mixed precision (AMP). This runs most forward- and backward-pass operations in float16 while keeping the master weights and optimizer updates in float32. Typically close to a 2x speedup and a roughly 2x reduction in activation memory.
Gotcha: some operations are numerically unstable in float16 (autocast keeps those in float32), and small gradients can underflow to zero. The gradient scaler (torch.cuda.amp.GradScaler) handles the underflow by scaling the loss before backprop and unscaling the gradients before the optimizer step.
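The standard AMP pattern, sketched as a single training step; the model, optimizer, loss function, and batch are assumed to exist:

```python
import torch

scaler = torch.cuda.amp.GradScaler()


def amp_step(model, optimizer, criterion, x, y):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # forward pass runs selected ops in float16
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()      # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, then runs the optimizer step
    scaler.update()                    # adjusts the scale factor for the next iteration
    return loss.item()
```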
The Cycle
Training is iterative. The discipline:
- Train baseline
- Identify the failure mode (is it underfitting? overfitting? wrong examples? class imbalance?)
- Fix the most impactful issue
- Measure the delta
- Repeat
Don’t change multiple things at once. One variable at a time, so you know what caused the improvement.
The goal is not a perfect model — it’s a model that improves reliably in the direction of your objective, with an honest evaluation on held-out data.