PyTorch Fully Sharded Data Parallel (FSDP) speeds up model training by parallelizing the training data and by sharding model parameters, optimizer states, and gradients across multiple PyTorch processes. The current version is FSDP2, which is not backwards-compatible with the original FSDP; the original FSDP is deprecated as of PyTorch 2.11.0+.
If your model does not fit on a single GPU, you can use FSDP and request more GPUs to reduce the memory footprint on each GPU. The model parameters are sharded across the GPUs, and each training process receives a different subset of the training data. Model updates from each device are synchronized across devices, so all devices end up with the same model.
For a complete overview with examples, see https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
Environment Setup
For running FSDP at OSC, we recommend using a base PyTorch environment or cloning a base PyTorch environment and adding your project’s specific packages to it.
There are 5 main differences between FSDP and single GPU runs. See https://github.com/pytorch/examples/tree/main/distributed/FSDP2 for a detailed code example. Note, however, that the example supports multigpu but not multinode training. For the changes required to support multinode training, see the Multinode Changes section below.
1. FSDP Imports
from torch.distributed.fsdp import fully_shard
2. Model Initialization
Apply fully_shard on each layer first, and then on the full model.
model = Transformer()
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)
3. FSDP Process Group Initialization
This is called toward the start of main.
import os
import torch

# set device
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
    device = torch.device(f"{device_type}:{rank}")
    torch.accelerator.set_device_index(rank)

# initialize process group
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, device_id=device)
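The same initialize/use/destroy lifecycle can be exercised in a single CPU-only process with the gloo backend, which is handy for testing outside a GPU job. A minimal sketch (the MASTER_ADDR/MASTER_PORT values are assumptions for a local one-process test; under torchrun they are set for you):

```python
import os
import torch.distributed as dist

# Assumed rendezvous values for a local single-process test run;
# torchrun normally exports these for every worker.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# gloo is the CPU backend; on GPU nodes,
# get_default_backend_for_device() would pick nccl instead.
dist.init_process_group(backend="gloo", rank=0, world_size=1)

rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"rank {rank} of {world_size}")

# Clean up once all distributed work is done (step 5 below).
dist.destroy_process_group()
```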
4. Define and set prefetching
def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
    for i, layer in enumerate(model.layers):
        if i >= len(model.layers) - num_to_forward_prefetch:
            break
        layers_to_prefetch = [
            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
        ]
        layer.set_modules_to_forward_prefetch(layers_to_prefetch)

def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
    for i, layer in enumerate(model.layers):
        if i < num_to_backward_prefetch:
            continue
        layers_to_prefetch = [
            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
        ]
        layer.set_modules_to_backward_prefetch(layers_to_prefetch)

# <model setup>
...

if args.explicit_prefetching:
    set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
    set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)
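To see which layers the forward-prefetch index arithmetic actually selects, the helper can be run against stand-in objects instead of real FSDP modules (DummyLayer and DummyModel below are hypothetical classes for illustration only; backward prefetching is symmetric, walking toward earlier layers):

```python
# Stand-in for an FSDP-wrapped layer: records what it is told to prefetch.
class DummyLayer:
    def __init__(self, idx):
        self.idx = idx
        self.forward_prefetch = []

    def set_modules_to_forward_prefetch(self, layers):
        self.forward_prefetch = [l.idx for l in layers]

class DummyModel:
    def __init__(self, n):
        self.layers = [DummyLayer(i) for i in range(n)]

def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
    for i, layer in enumerate(model.layers):
        if i >= len(model.layers) - num_to_forward_prefetch:
            break
        layers_to_prefetch = [
            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
        ]
        layer.set_modules_to_forward_prefetch(layers_to_prefetch)

model = DummyModel(5)
set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
print([l.forward_prefetch for l in model.layers])
# [[1, 2], [2, 3], [3, 4], [], []]
# Each layer prefetches the next two; the final layers prefetch nothing.
```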
5. Destroy the process group after training/validation and any post-processing have completed, i.e., outside the training loop.
torch.distributed.destroy_process_group()
Multinode Changes
There are also 3 changes needed for multinode training, rather than just multigpu training. See https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py
1. Use DistributedSampler to load data (required)
from torch.utils.data.distributed import DistributedSampler
sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
train_kwargs = {'batch_size': train_config.batch_size_training, 'sampler': sampler1}
cuda_kwargs = {'num_workers': train_config.num_workers_dataloader, 'pin_memory': True, 'shuffle': False}
train_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
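DistributedSampler gives each rank a disjoint slice of the dataset. This can be verified on a CPU without initializing a process group by passing rank and num_replicas explicitly; a small sketch (the toy 8-item dataset is an assumption for illustration):

```python
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 items; with 2 replicas each rank draws 4 disjoint indices.
dataset = list(range(8))

indices_per_rank = []
for rank in range(2):
    # shuffle=False makes the round-robin split easy to see;
    # in training you would use shuffle=True plus sampler.set_epoch(epoch).
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    indices_per_rank.append(list(sampler))

print(indices_per_rank)
# [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Together the two ranks cover every index exactly once, which is what keeps each training process on its own subset of the data.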
2. Global vs local rank tracked separately (optional, but recommended)
class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])  # use local rank to load model to device
        self.global_rank = int(os.environ["RANK"])
        ...
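torchrun exports LOCAL_RANK (the process index on its own node) and RANK (the index across all nodes) for every worker; with N processes per node, RANK = node_index * N + LOCAL_RANK. A sketch with assumed values for worker 3 on the second of two 4-GPU nodes (the environment variables are set manually here only to stand in for what torchrun provides):

```python
import os

# Assumed values for illustration; torchrun sets these for each worker.
os.environ["LOCAL_RANK"] = "3"
os.environ["RANK"] = "7"
os.environ["WORLD_SIZE"] = "8"

local_rank = int(os.environ["LOCAL_RANK"])   # index on this node: selects the GPU
global_rank = int(os.environ["RANK"])        # index across all nodes: selects rank-0-only work
world_size = int(os.environ["WORLD_SIZE"])   # total number of processes

procs_per_node = 4                            # assumed: 4 GPUs per node
node_index = global_rank // procs_per_node
print(node_index, local_rank, global_rank)   # 1 3 7
```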
3. Only save checkpoints where global_rank=0 (optional, but recommended)
if self.global_rank == 0:
    ...
    torch.save(model_state_dict, new_model_checkpoint)
    torch.save(optim_state_dict, new_optim_checkpoint)
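The rank-0 guard prevents every process from writing the same checkpoint file. A minimal single-process sketch of the pattern with an ordinary (non-sharded) model is below; note this is an assumption-laden illustration, since sharded FSDP2 state dicts need extra gathering (e.g. via torch.distributed.checkpoint) before they can be saved this way:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Assume a single-process run, so this process is global rank 0.
global_rank = 0

ckpt_dir = tempfile.mkdtemp()
model_path = os.path.join(ckpt_dir, "model.pt")
optim_path = os.path.join(ckpt_dir, "optim.pt")

# Only one rank writes; other ranks would skip this block.
if global_rank == 0:
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optim_path)

# Any rank can later restore identical weights from the checkpoint.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(model_path))
```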
Example Slurm Job Script Using srun and torchrun
#!/bin/bash
#SBATCH --job-name=fsdp-t5-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=96
nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO
ml miniconda3/24.1.2-py310
conda activate fsdp
srun torchrun \
--nnodes 2 \
--nproc_per_node 4 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/path/to/examples/distributed/T5-fsdp/fsdp_t5.py