Why Distribution Strategies Matter
A single GPU has a fixed memory ceiling and compute budget. When model training exceeds these limits — whether due to dataset size, model size, or time constraints — distributed training becomes necessary. TensorFlow’s tf.distribute API provides first-class abstractions for distributing work across GPUs and machines without rewriting training code.
Three strategies cover most use cases: MirroredStrategy for single-machine multi-GPU, MultiWorkerMirroredStrategy for multi-machine multi-GPU, and ParameterServerStrategy for asynchronous distributed training.
MirroredStrategy: Synchronous Multi-GPU on One Machine
How It Works
MirroredStrategy trains on multiple GPUs on a single machine using synchronous data parallelism:
- Initialization: model variables are created and replicated on every GPU — each GPU holds an identical copy of the model
- Forward pass: each GPU receives a different shard of the batch and computes its forward pass independently, in parallel
- Backward pass: each GPU computes gradients for its shard independently
- AllReduce: gradients are aggregated across all GPUs using all-reduce algorithms (NCCL on NVIDIA GPUs), producing a single combined gradient
- Update: all GPUs update their local copy with the combined gradient — they remain identical after each step
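The five steps above can be sketched in plain Python, with no TensorFlow. The names `shard_gradients`, `all_reduce_mean`, and the one-parameter toy model are purely illustrative, not part of the tf.distribute API:

```python
# Toy simulation of synchronous data parallelism: each "replica" computes a
# gradient on its shard, gradients are all-reduced (averaged), and every
# replica applies the same combined update.

def shard_gradients(batch, num_replicas, grad_fn):
    """Split the batch across replicas and compute per-replica gradients."""
    shard_size = len(batch) // num_replicas
    shards = [batch[i * shard_size:(i + 1) * shard_size]
              for i in range(num_replicas)]
    return [grad_fn(s) for s in shards]

def all_reduce_mean(grads):
    """Combine per-replica gradients into one global gradient (mean)."""
    return sum(grads) / len(grads)

# Gradient of the mean of (w * x)**2 for a one-parameter model: mean(2*w*x**2)
w = 1.0
batch = [1.0, 2.0, 3.0, 4.0]
grad_fn = lambda shard: sum(2 * w * x * x for x in shard) / len(shard)

per_replica = shard_gradients(batch, num_replicas=2, grad_fn=grad_fn)
combined = all_reduce_mean(per_replica)   # equals the full-batch gradient
w -= 0.1 * combined                       # every replica applies this update
```

Because the all-reduced gradient equals the full-batch gradient, every replica stays identical after the update, which is why MirroredStrategy preserves single-GPU convergence behavior.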
```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = build_model()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# The global batch in `dataset` is automatically divided across GPUs
model.fit(dataset, epochs=10)
```
The effective batch size equals batch_size_per_gpu × num_gpus. Adjust the learning rate accordingly (linear scaling rule: multiply LR by the number of GPUs when scaling batch size).
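As a concrete instance of the scaling rule (the numbers are illustrative):

```python
# Linear scaling rule: when the effective batch size grows by num_gpus,
# scale the base learning rate by the same factor.

def scaled_learning_rate(base_lr, num_gpus):
    return base_lr * num_gpus

base_lr = 0.001            # tuned for batch_size_per_gpu on a single GPU
batch_size_per_gpu = 64
num_gpus = 8

effective_batch = batch_size_per_gpu * num_gpus   # 512
lr = scaled_learning_rate(base_lr, num_gpus)      # 0.008
```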
Benefits and Limitations
Benefits:
- Simple API: wrap model creation in strategy.scope(); everything else is automatic
- Fast, stable convergence for typical model sizes: all GPUs make identical, globally consistent updates
- High GPU utilization: all GPUs are active simultaneously during the forward and backward passes
- Clean failure semantics: if one GPU fails, training fails as a whole, with no partial updates
Limitations:
- Single machine: limited to the number of GPUs on one node
- Communication overhead: AllReduce scales with parameter count — for very large models (billions of parameters), gradient communication dominates compute time
- Memory: each GPU must fit the full model
When AllReduce Becomes a Bottleneck
The threshold depends on hardware, but as a rule: when training time starts being dominated by inter-GPU communication rather than compute, consider ParameterServerStrategy or model parallelism. On modern NVLink interconnects (A100, H100), this threshold is much higher than on PCIe connections.
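A back-of-envelope model makes the threshold concrete. Ring all-reduce moves roughly 2 × (N − 1) / N × model_bytes per GPU per step (the standard bandwidth-optimal cost); dividing by link bandwidth gives a lower bound on communication time. The bandwidth figures below are rough, illustrative values:

```python
# Estimate per-step AllReduce time for a 1B-parameter fp32 model on 8 GPUs.

def ring_allreduce_seconds(num_params, bytes_per_param, num_gpus, bandwidth_gb_s):
    volume = 2 * (num_gpus - 1) / num_gpus * num_params * bytes_per_param
    return volume / (bandwidth_gb_s * 1e9)

params = 1_000_000_000
fp32 = 4

t_pcie = ring_allreduce_seconds(params, fp32, 8, bandwidth_gb_s=16)    # ~PCIe-class link
t_nvlink = ring_allreduce_seconds(params, fp32, 8, bandwidth_gb_s=300) # ~NVLink-class link
# If per-step compute is ~0.5 s, PCIe communication (~0.44 s) is comparable
# to compute, while NVLink communication (~0.02 s) is negligible.
```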
ParameterServerStrategy: Asynchronous Multi-Worker Training
How It Works
ParameterServerStrategy uses a parameter server architecture: a set of dedicated servers hold model variables, while worker nodes perform computation.
- Initialization: model variables are placed on parameter servers; each worker receives a copy
- Forward pass: each worker processes its own data shard using its local parameters
- Backward pass: each worker computes gradients independently
- Update: workers push their gradients to the parameter servers asynchronously; parameter servers apply updates immediately without waiting for other workers
- Distribute: workers pull the updated parameters from parameter servers before the next step
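The asynchronous update cycle above, and the stale-gradient hazard it creates, can be simulated in a few lines of plain Python. The `ParameterServer` class and one-parameter objective are illustrative, not TensorFlow API:

```python
# Toy simulation of asynchronous parameter-server updates. Workers pull the
# current parameter, compute a gradient, and push it back; the server applies
# each push immediately, so a slow worker's gradient may be based on an
# already-outdated parameter (a "stale" gradient).

class ParameterServer:
    def __init__(self, w):
        self.w = w
    def pull(self):
        return self.w
    def push(self, grad, lr=0.1):
        self.w -= lr * grad   # applied immediately, no synchronization

def gradient(w):
    return 2 * w  # gradient of the objective w**2

server = ParameterServer(w=1.0)

w_fast = server.pull()   # both workers pull w = 1.0
w_slow = server.pull()

server.push(gradient(w_fast))  # fast worker pushes first: w = 1.0 - 0.2 = 0.8
server.push(gradient(w_slow))  # slow worker's gradient is stale (based on 1.0)
# w is now 0.6, not the 0.64 that two fully sequential synchronous steps
# (gradient recomputed at w = 0.8) would have produced.
```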
```python
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

with strategy.scope():
    model = build_model()  # build_model and train_step are defined elsewhere

# Wrap the dataset so each worker builds its own iterator
per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

@tf.function
def per_worker_step(iterator):
    return strategy.run(train_step, args=(next(iterator),))

for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        coordinator.schedule(per_worker_step, args=(per_worker_iterator,))
    coordinator.join()  # wait for all scheduled steps to finish
```
Benefits and Limitations
Benefits:
- Scalability: designed for large numbers of workers across many machines
- Asynchronous throughput: workers do not wait for each other — no stragglers blocking the entire step
- Flexible: works with heterogeneous hardware (workers can have different GPU counts)
Limitations:
- Stale gradients: a slow worker may apply gradients computed from parameters that have already been updated by faster workers — this can slow convergence or cause instability
- Communication overhead: workers continuously push/pull from parameter servers — heavy bidirectional traffic
- Complexity: requires coordinating separate processes for workers and parameter servers
MultiWorkerMirroredStrategy: Synchronous Multi-Machine Multi-GPU
The logical extension of MirroredStrategy to multiple machines. Each machine uses MirroredStrategy internally, and the machines coordinate via AllReduce across machines (using NCCL for GPU-GPU communication over high-speed interconnects, or gRPC for slower networks).
```python
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = build_model()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

model.fit(dataset, epochs=10)
```
The TF_CONFIG environment variable must be set on each worker to define the cluster topology.
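A minimal TF_CONFIG for a two-worker cluster might look like the following; the hostnames and ports are placeholders. Every worker gets the same "cluster" dict but its own "task" index:

```python
import json
import os

# Set before creating MultiWorkerMirroredStrategy on each machine.
tf_config = {
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"],
    },
    "task": {"type": "worker", "index": 0},  # 0 on the first machine, 1 on the second
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)
```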
Best for: training large models on a cluster of GPU machines with a high-bandwidth interconnect such as InfiniBand. The synchronous updates preserve all the convergence properties of single-GPU training.
Choosing the Right Strategy
| Scenario | Strategy |
|---|---|
| Multiple GPUs, single machine | MirroredStrategy |
| Multiple machines, homogeneous hardware, fast interconnect | MultiWorkerMirroredStrategy |
| Many workers, heterogeneous hardware, or asynchronous updates acceptable | ParameterServerStrategy |
| Model too large for one GPU | Model parallelism (pipeline or tensor) — not covered by these strategies |
MirroredStrategy vs. ParameterServerStrategy: When Is Each Faster?
MirroredStrategy is typically faster when:
- The model fits on each GPU
- AllReduce communication time is small relative to compute time
- Synchronous updates are important for convergence quality
ParameterServerStrategy is typically faster when:
- Workers have highly variable computation time (stragglers)
- The network bandwidth between machines is limited
- The number of workers is very large (AllReduce overhead grows with worker count)
The practical test: run both on your hardware and compare wall-clock time per epoch and final convergence quality. The theoretical analysis does not always match hardware-specific performance.
Production Notes
Fault tolerance: MirroredStrategy has no built-in fault tolerance — one GPU failure kills the run. ParameterServerStrategy handles worker failures gracefully (a worker that fails is restarted and pulls current parameters). For long training runs on cloud hardware with preemption, this makes ParameterServerStrategy more resilient.
Batch size scaling: doubling the number of workers with MirroredStrategy doubles the effective batch size. This typically requires a corresponding increase in learning rate (linear scaling rule) and may require learning rate warmup for stability.
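A common way to combine linear scaling with warmup is to ramp the learning rate from the base value to the scaled value over the first few hundred or thousand steps, then hold. This is a minimal sketch; Keras users would typically wrap the same logic in a `tf.keras.optimizers.schedules.LearningRateSchedule` subclass:

```python
# Linear-scaling-with-warmup schedule: ramp linearly from base_lr to
# base_lr * num_workers over warmup_steps, then hold the scaled value.

def lr_at_step(step, base_lr, num_workers, warmup_steps):
    target = base_lr * num_workers          # linear scaling rule
    if step >= warmup_steps:
        return target
    frac = step / warmup_steps
    return base_lr + frac * (target - base_lr)

base_lr, workers, warmup = 0.001, 8, 1000
start = lr_at_step(0, base_lr, workers, warmup)     # 0.001
mid = lr_at_step(500, base_lr, workers, warmup)     # halfway: 0.0045
end = lr_at_step(1000, base_lr, workers, warmup)    # 0.008
```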
tf.function and distribution: always wrap the custom training step in @tf.function when using distribution strategies. The function is traced into a graph once and that graph is executed on every replica; running eagerly with distribution strategies is significantly slower.