
Scale Up: Distributed Training in PyTorch#

Welcome to this comprehensive guide on distributed training in PyTorch. In this post, we’ll explore how to scale up your deep learning workloads from a single GPU on your workstation to multiple GPUs across multiple machines. The goal is to gradually build your knowledge, starting with the core concepts of distributed training and culminating in professional-level expansions and best practices. Along the way, we’ll weave in code examples, tables, and practical tips to help you become adept at training models faster and more efficiently in a distributed manner.


Table of Contents#

  1. Introduction
  2. Why Distributed Training?
  3. Foundations of PyTorch Distributed
  4. Essential Terminology and Concepts
  5. Single-Node, Multi-GPU: DataParallel in PyTorch
  6. DistributedDataParallel: Scaling Beyond One Machine
  7. Communication Backends and Collectives
  8. Launching Distributed PyTorch Jobs
  9. Multi-Node Training with Examples
  10. Understanding and Using Process Groups
  11. Advanced Distributed Techniques
  12. Performance Tips and Debugging Distributed Training
  13. A Quick Comparison of Distributed Approaches
  14. Real-World Use Cases
  15. Conclusion and Further Reading

Introduction#

Deep learning models and datasets continue to grow in complexity and size. At a certain point, training on a single GPU—or even a single machine—becomes too slow to be practical. When massive datasets and architectures like BERT, GPT, or large-scale image recognition models are on the docket, distributed training is often the only viable method to train these models in a reasonable timeframe.

PyTorch has become one of the most popular deep learning frameworks, prized for its flexibility, ecosystem, vibrant community, and alignment with Pythonic programming paradigms. As your projects scale, understanding PyTorch’s distributed capabilities is critical for accelerating training time, enabling larger experiments, and potentially saving on costs by utilizing resources efficiently.

In this blog post, we’ll start by explaining why distributed training is valuable. Then, we’ll dive into how PyTorch handles parallelization. We’ll examine different distributed paradigms, from the simplest (DataParallel) to advanced frameworks and best practices. By the end, you’ll have a strong grasp of how to scale your workloads while maintaining correctness and reproducibility.


Why Distributed Training?#

Before delving into technical details, let’s clarify the “why” behind distributed training:

  1. Acceleration: Splitting workload across multiple GPUs or nodes can dramatically cut down training time. For large models, it may be impossible to train effectively on a single machine or GPU.
  2. Larger Models: As models grow, fitting your network (or even a single batch) into GPU memory on one machine might be infeasible. Distributing the computation can help accommodate larger models.
  3. Bigger Datasets: Datasets (e.g., large-scale image classification, massive text corpora) may be too big to process quickly with a single GPU. Distributed training speeds up data loading and processing.
  4. Resource Utilization: Many organizations have multiple GPUs spread across clusters. Distributed training helps leverage these resources efficiently, saving time and possibly money on cloud-based GPUs.

Foundations of PyTorch Distributed#

PyTorch’s distributed API is built around the concept of processes that communicate with each other to collectively train a model. Each process typically operates on a subset of the data (in data-parallel approaches) or a subset of the model’s parameters (in model-parallel or pipeline-parallel approaches).

At a high level, distributed PyTorch includes:

  • Initialization: Processes must form a distributed group, specifying the communication backend (e.g., NCCL, Gloo, MPI) and a unique rank for each process (a minimal sketch follows this list).
  • Distribution Strategy: Different ways of distributing your model’s parameters and your data across processes.
  • Synchronization: Gradients (or parameters) need to be communicated across processes in a synchronized manner.
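
In code, the initialization step boils down to a single call. Here is a minimal sketch, assuming a launcher such as torchrun has already set MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the environment:

import torch.distributed as dist

# Form the default process group; rank and world size are read from the
# environment variables populated by the launcher.
dist.init_process_group(backend="nccl", init_method="env://")

print(f"Initialized rank {dist.get_rank()} of {dist.get_world_size()} processes")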

Essential Terminology and Concepts#

Before getting hands-on, let’s define some key terms you’ll encounter:

| Term | Definition |
| --- | --- |
| Rank | A unique identifier for each process in a distributed job. For example, in a job with eight processes, ranks go from 0 to 7. |
| World Size | The total number of processes that participate in the distributed training job. |
| Master (or Chief) Process | Rank 0 is often called the master or chief process, responsible for tasks like logging outputs or orchestrating the job. |
| Backend | The library used for communication. Common backends: NCCL (recommended for GPU-based training), Gloo (CPU-based, works across machines as well), MPI. |
| Process Group | A group of processes that can communicate with each other collectively using PyTorch's APIs. |

Communication Collectives#

Collective operations are crucial in distributed computing. They perform operations like:

  • broadcast
  • all_reduce
  • gather and scatter
  • all_gather
  • reduce_scatter

For example, an all_reduce sums a tensor from all processes and distributes the result back to every process. This is fundamental for synchronizing gradients in data-parallel training.
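
To make this concrete, here is a hedged sketch of manually averaging gradients with all_reduce; DDP (covered below) performs an equivalent synchronization for you automatically. The model and an initialized process group are assumed to already exist:

import torch.distributed as dist

# Manual gradient averaging: sum each gradient across all processes,
# then divide by the world size so every rank ends up with the mean.
def average_gradients(model):
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size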

Data-Parallel vs. Model-Parallel#

  1. Data-Parallel: Each process holds the same model but works on a subset of the data. After computing gradients, processes synchronize their model parameters (or gradients) so they remain consistent.
  2. Model-Parallel: The model itself is split across multiple GPUs or nodes (e.g., some layers on GPU 0 and others on GPU 1). Model-parallelism is often used when a model is too large to fit into a single GPU memory space.

Single-Node, Multi-GPU: DataParallel in PyTorch#

Although not strictly part of the torch.distributed package, torch.nn.DataParallel is the easiest place to start. It is typically used on a single machine with multiple GPUs to parallelize computation over the data.

import torch
import torch.nn as nn
import torch.optim as optim

# Assume we have a model and data
class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# Instantiate the model
model = SimpleModel(100, 10)
model = nn.DataParallel(model)  # Wrap model with DataParallel

# Move to GPU if available
if torch.cuda.is_available():
    model.cuda()

# Example forward pass
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# dummy_input: batch_size of 32, input features 100
dummy_input = torch.randn(32, 100)
dummy_labels = torch.randint(0, 10, (32,))
if torch.cuda.is_available():
    dummy_input = dummy_input.cuda()
    dummy_labels = dummy_labels.cuda()

# Forward + backward
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)
loss.backward()
optimizer.step()

How it works:

  • DataParallel replicates the model to all available GPUs on the same machine.
  • It splits the input batch among the GPUs.
  • It gathers the partial outputs on the default GPU, and during the backward pass the gradients from each replica are reduced onto the original model on that device.

Despite its simplicity, DataParallel can become a bottleneck on the main GPU or CPU because it manages gradient aggregation centrally. This is why for multi-GPU, multi-node scenarios, the DistributedDataParallel (DDP) module is generally preferred.


DistributedDataParallel: Scaling Beyond One Machine#

torch.nn.parallel.DistributedDataParallel (DDP) is the recommended way to do data-parallel training across multiple processes. Each process typically drives a single GPU (though you can run multiple processes per machine if you have multiple GPUs). Gradients are synchronized with bucketed all-reduce operations that overlap with the backward pass, which is far more efficient than DataParallel's single-process approach.

Key Steps for Using DDP#

  1. Initialization: Initialize the distributed process group (backend, rank, world size).
  2. Model Wrapping: Create the model and wrap it with DistributedDataParallel.
  3. Data Loading: Use DistributedSampler to ensure each process sees a unique subset of the dataset.
  4. Train: Each process handles its subset of data. Gradients are synchronized automatically by DDP after backward().

Below is a simplified code snippet for single-node, multi-GPU training with DDP. Suppose you have 4 GPUs (IDs 0 through 3), and you will launch 4 processes, each using a single GPU.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def train(rank, world_size):
    # When spawning processes ourselves, each one needs to know where the
    # rendezvous is (torchrun would set these variables for us).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Define your backend. For GPUs, NCCL is recommended.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Create model, move it to the corresponding GPU
    model = torch.nn.Linear(10, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank)

    # Dummy dataset
    x = torch.randn(100, 10)
    y = torch.randn(100, 1)
    dataset = TensorDataset(x, y)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

    # Loss and optimizer
    criterion = torch.nn.MSELoss().to(rank)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # Training loop
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(rank), batch_y.to(rank)
            optimizer.zero_grad()
            pred = ddp_model(batch_x)
            loss = criterion(pred, batch_y)
            loss.backward()
            optimizer.step()
        if rank == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

    # Clean up
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

Why Use DDP Over DataParallel?#

  • Performance: Reduced overhead by using efficient collective communication (e.g., NCCL all-reduce) rather than funneling gradients through a single GPU.
  • Scalability: Multi-node training support.
  • Flexibility: Fine-grained control over how processes connect, how data is sampled, etc.

Communication Backends and Collectives#

In PyTorch, the major communication backends are:

  • NCCL: Primarily used for GPU-based collective communication. Developed by NVIDIA. Excellent performance on multi-GPU systems.
  • Gloo: CPU-based collective communication library. Works across machines and is a common fallback for CPU-based training or environments without NVIDIA GPUs.
  • MPI: Supports MPI-based collective communication. Requires an MPI environment.

Use NCCL when you have NVIDIA GPUs. If your environment does not have GPUs, Gloo is a reliable default.
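
In code, a common pattern is to pick the backend based on what hardware is available, for example:

import torch
import torch.distributed as dist

# Prefer NCCL when CUDA devices are present; otherwise fall back to Gloo.
# Assumes the launcher has populated the standard environment variables.
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")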

Common Collectives:

  • dist.all_reduce(tensor): e.g., sum the tensors from all processes and distribute the result back to each process.
  • dist.broadcast(tensor, src=0): Send a tensor from the src rank to all processes.
  • dist.reduce(tensor, dst=0): Aggregate tensors from all processes and store the result on the dst rank.
  • dist.all_gather(tensor_list, tensor): Collect each rank's tensor into tensor_list, so every process ends up with one entry per rank.
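
The toy snippet below shows two of these collectives in action. It assumes every rank runs it after the process group has been initialized; with Gloo, the CPU tensors shown here work as-is, while with NCCL you would move them to the rank's GPU first:

import torch
import torch.distributed as dist

rank = dist.get_rank()

# Every rank contributes its rank id; after all_reduce, each rank holds the sum.
t = torch.tensor([float(rank)])
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# Rank 0 fills a tensor and broadcasts it to every other rank.
msg = torch.zeros(1)
if rank == 0:
    msg += 42.0
dist.broadcast(msg, src=0)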

Launching Distributed PyTorch Jobs#

When training in a distributed manner, we create one process per GPU (often called “single-process, single-GPU” usage). PyTorch offers multiple ways to launch distributed jobs:

  1. torch.multiprocessing.spawn: Pythonic approach to spawning processes within the same script. Commonly used for single-node.
  2. torchrun (previously python -m torch.distributed.launch): Command-line approach to launching multi-process training across one or multiple nodes.
  3. Custom HPC Scripts: On HPC clusters, you might have your own resource manager (e.g., Slurm job scripts).

Using torchrun#

torchrun is the recommended way for most scenarios. Here’s a minimal example for a single node with 2 GPUs:

torchrun \
--nnodes=1 \
--nproc_per_node=2 \
train.py

For multi-node, you would specify:

torchrun \
--nnodes=2 \
--node_rank=0 \
--master_addr="192.168.0.10" \
--master_port=12345 \
--nproc_per_node=2 \
train.py

And on the second node:

torchrun \
--nnodes=2 \
--node_rank=1 \
--master_addr="192.168.0.10" \
--master_port=12345 \
--nproc_per_node=2 \
train.py

Multi-Node Training with Examples#

Distributed training with DistributedDataParallel just scales horizontally to more nodes. Each node spawns multiple GPU processes. The main change from a single-node scenario is specifying the environment variables (or arguments to torchrun) so that each node knows how to connect to the others.

Importantly, each process has a globally unique rank. If you have 2 nodes each with 2 GPUs, your world size is 4. Ranks might be assigned as follows:

  • Node 0: rank 0, rank 1
  • Node 1: rank 2, rank 3

All processes collectively form a single distributed group. Data partitioning is typically managed by DistributedSampler or a custom sampler that ensures non-overlapping data subsets.
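
Inside train.py, each process can read its coordinates from the environment variables that torchrun sets. A minimal sketch of the setup code might look like this:

import os
import torch
import torch.distributed as dist

# torchrun exports RANK (global), LOCAL_RANK (GPU index on this node), and WORLD_SIZE.
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(local_rank)

print(f"Global rank {dist.get_rank()} of {dist.get_world_size()}, local GPU {local_rank}")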


Understanding and Using Process Groups#

Under the hood, PyTorch’s distributed package lumps processes into a default group. However, you can create subgroups if you have more advanced control needs (for instance, separate groups of processes dedicated to different tasks).

Example of creating a subgroup:

import torch
import torch.distributed as dist

# Suppose the default process group has world_size=4.
# Note: dist.new_group must be called by every process in the default group,
# even those that are not members of the new subgroup.
group = dist.new_group(ranks=[0, 1])

if dist.get_rank() in [0, 1]:
    # Only ranks 0 and 1 participate in this collective
    tensor = torch.tensor([dist.get_rank()])
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)

This approach can be handy for multi-stage models, pipeline parallelism, or specialized tasks within your distributed job.


Advanced Distributed Techniques#

Beyond vanilla data-parallel training with DDP, PyTorch offers a range of advanced or specialized options:

  1. Model Parallelism: If your model is too large for a single GPU, you can split the model layers across multiple GPUs. This can be done manually (a minimal sketch follows this list) or with higher-level frameworks (e.g., Megatron-LM for large language models).
  2. Pipeline Parallelism: Partition your model by layer or stage, so different GPUs handle different segments of the model pipeline. Particularly relevant for extremely large, multi-stage models.
  3. Fully Sharded Data Parallel (FSDP): A technique that shards your model parameters across data-parallel workers, lowering memory usage when training massive models.
  4. RPC-Based Frameworks: With the “Remote Procedure Call” (RPC) framework in PyTorch, you can build more flexible distributed systems. For instance, parameter servers or specialized distributed pipelines.
  5. Elastic Training: PyTorch Elastic allows you to dynamically add or remove nodes/workers from a distributed job. This is helpful in cloud spot-instance scenarios.
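
As a concrete illustration of item 1, here is a minimal, hand-rolled model-parallel sketch. It assumes two visible GPUs, and the layer sizes are arbitrary:

import torch
import torch.nn as nn

# The first linear layer lives on cuda:0 and the second on cuda:1;
# activations are moved between devices inside forward().
class TwoGPUModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024).to("cuda:0")
        self.layer2 = nn.Linear(1024, 10).to("cuda:1")

    def forward(self, x):
        x = self.layer1(x.to("cuda:0"))
        return self.layer2(x.to("cuda:1"))

model = TwoGPUModel()
out = model(torch.randn(8, 1024))  # output lives on cuda:1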

Example: Using FSDP#

FSDP (Fully Sharded Data Parallel) is an advanced approach that shards parameters (and optionally gradients, optimizer states) across multiple GPUs to reduce memory usage. A minimal snippet might look like this:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def fsdp_main(rank, world_size):
    # As before, MASTER_ADDR/MASTER_PORT must be provided by the launcher
    # (or set in the environment) for the rendezvous to succeed.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(1024, 1024).cuda(rank)

    # Wrap your model with FSDP so its parameters are sharded across ranks
    fsdp_model = FSDP(model)

    # Training loop ...
    dist.destroy_process_group()

When using FSDP, keep in mind that the overhead of parameter sharding and re-sharding can be non-trivial, but for large models that exceed single-GPU memory budgets, it can be indispensable.


Performance Tips and Debugging Distributed Training#

  1. Locality: Place each process on the GPU corresponding to its local rank, e.g., local rank 0 uses GPU 0 on its node.
  2. Batch Size: Increase batch size to ensure GPUs are well utilized, but watch out for hardware constraints and potential generalization issues.
  3. Network Bandwidth: For multi-node training, ensure high-speed interconnects (e.g., InfiniBand or high-speed Ethernet).
  4. Gradient Accumulation: Especially with large effective batch sizes, you can accumulate gradients over several steps to reduce communication overhead (see the sketch after this list).
  5. Profiling and Debugging: Tools like PyTorch’s profiler or Nvidia Nsight can help pinpoint bottlenecks.
  6. Logging: Each rank might log separately. Tools like TensorBoard can aggregate logs.
  7. Seeding and Shuffling: For reproducibility, set each worker’s random seed consistently, and call sampler.set_epoch(epoch) on DistributedSampler so shuffling stays consistent across ranks while varying between epochs.
  8. Sync Errors: If you see issues that only occur at scale, use debugging techniques such as dist.barrier() to isolate which step might be hanging or failing.
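
Picking up tip 4, below is a hedged sketch of gradient accumulation with DDP's no_sync() context manager; ddp_model, optimizer, criterion, and dataloader are assumed to exist as in the earlier DDP example:

# Accumulate gradients locally for several steps, synchronizing only on the last one.
accumulation_steps = 4

for step, (batch_x, batch_y) in enumerate(dataloader):
    if (step + 1) % accumulation_steps != 0:
        # Skip gradient all-reduce on intermediate accumulation steps.
        with ddp_model.no_sync():
            loss = criterion(ddp_model(batch_x), batch_y)
            (loss / accumulation_steps).backward()
    else:
        loss = criterion(ddp_model(batch_x), batch_y)
        (loss / accumulation_steps).backward()  # gradients sync on this backward
        optimizer.step()
        optimizer.zero_grad()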

Common Pitfalls#

  • Not using DistributedSampler: Every process then iterates over the full dataset, so ranks see overlapping data, which wastes compute and can distort convergence.
  • Forgetting sampler.set_epoch(epoch): Without setting the epoch, the sampler reuses the same shuffle order every epoch, leading to less thorough data shuffling.
  • Rank mismatch: Confusion over environment variables or node_rank can cause processes to form incomplete groups or fail to synchronize.

A Quick Comparison of Distributed Approaches#

Below is a simplified table summarizing the main distributed approaches offered by PyTorch:

| Approach | Scope | Complexity | Typical Use Case |
| --- | --- | --- | --- |
| DataParallel (DP) | Single-node, multi-GPU | Low | Quick prototyping on a single machine with multiple GPUs |
| DistributedDataParallel (DDP) | Multi-node, multi-GPU | Medium | Production-grade training on clusters |
| Fully Sharded Data Parallel (FSDP) | Multi-node, multi-GPU | High | Extremely large models requiring memory sharding |
| Model Parallel / Pipeline Parallel | Multi-node, multi-GPU | High | Very large models that do not fit into GPU memory with data parallel alone |

As your needs become more specialized, you can combine these approaches (e.g., data parallel + pipeline parallel) for extremely large-scale training.


Real-World Use Cases#

  1. Language Modeling (GPT, BERT): Huge models with billions of parameters. Commonly use model parallelism or mixed data and model parallel approaches (e.g., Megatron-LM, DeepSpeed).
  2. CV and Medical Imaging: Large 3D scans can require multiple GPUs to handle big volumes. DDP is often enough here for multi-GPU training.
  3. Reinforcement Learning: Many parallel workers (possibly thousands) generate rollouts or episodes, while the training step runs in a distributed or parallel fashion to consume them.
  4. RecSys: Large embedding tables for user/item interactions. Sharding embedding tables across multiple nodes is common.
  5. Industrial HPC: Simulations for finance, physics, or engineering, often integrated with HPC clusters. PyTorch’s distributed support integrates well with typical HPC setups (Slurm, MPI).

Conclusion and Further Reading#

We’ve covered a wide range of topics from simple single-machine parallelism (DataParallel) to advanced multi-node training with DistributedDataParallel and beyond. Properly implementing distributed training can help you scale your models and experiments efficiently, whether you have a local workstation with a few GPUs or a cluster with dozens of machines.

Further Reading and Resources#

Distributed training is a powerful tool that can cut your training time from weeks to days, or even hours—all while enabling models bigger than previously possible on a single GPU. Embrace the distributed features of PyTorch to future-proof your deep learning endeavors, experiment at scale, and make the most of modern GPU and HPC resources.


Happy scaling and training! Use these techniques responsibly, experiment with small prototypes before going full-scale, and rest assured that PyTorch’s distributed ecosystem has your back as you delve deeper into the world of large-scale deep learning.
