PyTorch Distributed Data Parallel (DDP) speeds up model training by parallelizing the training data across multiple identical copies of the model.
If your model fits on a single GPU but a large training set is making training slow, you can use DDP and request more GPUs to increase training speed. The entire model is replicated on each GPU, and each training process receives a different subset of the training data. During the backward pass, the gradients computed on each device are averaged (all-reduced) across devices, so every replica applies the same update and the models stay identical.
For a complete overview with video tutorial and examples, see https://pytorch.org/tutorials/beginner/ddp_series_intro.html
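Conceptually, DDP only adds a few lines around an ordinary training loop: initialize a process group, wrap the model, shard the data, and tear the group down at the end. The following is a minimal, self-contained sketch of that pattern (the toy linear model, random dataset, hyperparameters, and file name toy_ddp.py are illustrative assumptions, and the nccl backend assumes NVIDIA GPUs); the sections that follow walk through the same pieces in the context of a complete example.

# toy_ddp.py -- launch with: torchrun --nproc_per_node=<gpus_per_node> toy_ddp.py
import os
import torch
import torch.nn as nn
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    init_process_group(backend="nccl")           # torchrun supplies rank/world size
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # every process builds an identical replica and wraps it in DDP
    model = DDP(nn.Linear(20, 1).to(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # same dataset on every rank; DistributedSampler hands each rank a disjoint shard
    torch.manual_seed(0)
    dataset = TensorDataset(torch.randn(1024, 20), torch.randn(1024, 1))
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        sampler=DistributedSampler(dataset))

    for epoch in range(3):
        loader.sampler.set_epoch(epoch)          # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.to(local_rank), y.to(local_rank)
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()      # gradients are all-reduced here
            optimizer.step()

    destroy_process_group()

if __name__ == "__main__":
    main()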
Environment Setup
For running DDP at OSC, we recommend either using a base PyTorch environment as-is or cloning a base PyTorch environment and adding your project's specific packages to the clone.
There are six main differences between a DDP run and a single-process run. The following code examples are taken from https://github.com/pytorch/examples/tree/main/distributed/minGPT-ddp:
DDP Setup Function
DDP setup creates the process group and binds each process to its local GPU. This function is called near the start of main().
import os
import torch
from torch.distributed import init_process_group
def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
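Note that no rank or world size is passed to init_process_group(): torchrun launches each process with environment variables such as RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, and the default env:// rendezvous reads them automatically. A quick way to see what each process receives (a debugging snippet, not part of the minGPT example):

import os
# printed once by every process that torchrun launches
print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} "
      f"(local rank {os.environ['LOCAL_RANK']}) on {os.uname().nodename}")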
Trainer wraps model in DDP
from torch.nn.parallel import DistributedDataParallel as DDP
class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        ...
        self.model = DDP(self.model, device_ids=[self.local_rank])
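DDP expects the module to already live on its target GPU when it is wrapped, and the wrapped object exposes the original model as .module (useful later for checkpointing). A standalone sketch of the wrapping step, assuming ddp_setup() has already run and using a placeholder nn.Linear model:

import os
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ["LOCAL_RANK"])
model = nn.Linear(10, 10).to(local_rank)     # move the replica to its GPU first
model = DDP(model, device_ids=[local_rank])  # then wrap it for gradient syncing
raw_model = model.module                     # the unwrapped model is still reachable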
Use DistributedSampler to load data (and set shuffle=False)
from torch.utils.data.distributed import DistributedSampler
class Trainer:
    ...
    def _prepare_dataloader(self, dataset: Dataset):
        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            pin_memory=True,
            shuffle=False,
            num_workers=self.config.data_loader_workers,
            sampler=DistributedSampler(dataset)
        )
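Setting shuffle=False hands control of shuffling to the sampler. The sampler shuffles by default, but it reuses the same ordering every epoch unless it is told the epoch number; calling DistributedSampler.set_epoch() at the top of each epoch reshuffles the data while keeping the per-rank shards disjoint. A small sketch, assuming the process group is already initialized (e.g. run under torchrun after ddp_setup()):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))
sampler = DistributedSampler(dataset)             # shuffle=True by default
loader = DataLoader(dataset, batch_size=2, shuffle=False, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)   # without this, every epoch sees the same order
    for batch in loader:
        ...                    # training step goes here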
Destroy process group when done
from torch.distributed import destroy_process_group
def main():
    ...
    trainer.train()
    destroy_process_group()
Only save checkpoints where local_rank=0
class Trainer:
    ...
    def train(self):
        for epoch in range(self.epochs_run, self.config.max_epochs):
            epoch += 1  # switch to 1-based epoch numbering
            self._run_epoch(epoch, self.train_loader, train=True)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
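Because every rank holds identical weights after each update, saving from a single process is sufficient and avoids multiple processes writing the same file. Since self.model is a DDP wrapper, the snapshot should store the underlying module's weights so it can later be loaded into a plain, non-DDP model. A sketch of what _save_snapshot might look like (the dictionary keys and the snapshot_path config field are assumptions, not necessarily what the minGPT example uses):

import torch
class Trainer:
    ...
    def _save_snapshot(self, epoch):
        snapshot = {
            "model_state": self.model.module.state_dict(),  # unwrap DDP
            "epochs_run": epoch,
        }
        torch.save(snapshot, self.config.snapshot_path)
        print(f"Epoch {epoch} | snapshot saved to {self.config.snapshot_path}")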
Global vs local rank tracked separately
class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
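The two ranks answer different questions: LOCAL_RANK is the GPU index on the current node and is used for device placement, while RANK uniquely identifies the process across the entire job. With the two-node, one-GPU-per-node Slurm example below, both processes have local rank 0, but their global ranks are 0 and 1. A standalone sketch:

import os

local_rank = int(os.environ["LOCAL_RANK"])   # which GPU on this node
global_rank = int(os.environ["RANK"])        # which process across all nodes

if global_rank == 0:
    print("only one process in the whole job prints this")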
Example Slurm Job Script Using srun and torchrun
#!/bin/bash
#SBATCH --job-name=multinode-example-minGPT
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4
# use the first allocated node as the rendezvous host for torchrun
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO
ml miniconda3/24.1.2-py310
conda activate minGPT-ddp
# srun starts one torchrun per task (one per node); each torchrun launches
# --nproc_per_node training processes, i.e. one per GPU on that node
srun torchrun \
    --nnodes 2 \
    --nproc_per_node 1 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    /path/to/examples/distributed/minGPT-ddp/mingpt/main.py
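If you change the number of nodes or GPUs, keep the Slurm directives (--nodes, --ntasks, --gpus-per-task) and the torchrun arguments (--nnodes, --nproc_per_node) consistent with each other. Once the job starts, a quick sanity check is to log the process group from main.py right after ddp_setup() runs (a hypothetical addition, not part of the example); with the script above you should see two lines reporting a world size of 2:

import os
import torch.distributed as dist

# call after init_process_group() (i.e. after ddp_setup())
print(f"rank {dist.get_rank()} of {dist.get_world_size()} "
      f"on {os.uname().nodename}, using GPU {os.environ['LOCAL_RANK']}")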