HOWTO: PyTorch Fully Sharded Data Parallel (FSDP2)

PyTorch Fully Sharded Data Parallel (FSDP) is used to speed up model training by parallelizing the training data and sharding model parameters, optimizer states, and gradients across multiple PyTorch processes.  The current version is FSDP2, which is not backwards compatible with the original FSDP; the original is deprecated as of PyTorch 2.11.0+.


If your model does not fit on a single GPU, you can use FSDP and request more GPUs to reduce the memory footprint on each GPU.  The model parameters are split between the GPUs, and each training process receives a different subset of the training data.  Gradients are synchronized across devices during the backward pass, so the model remains consistent on all devices after each update.
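As a toy illustration of the idea (plain Python, not actual FSDP code), sharding splits the parameter storage across ranks so each rank holds only its slice, while each rank also trains on a different slice of the data:

```python
# Toy illustration of parameter sharding: NOT real FSDP code.
# A "model" with 8 parameter tensors, represented here only by their sizes in MB.
param_sizes_mb = [100] * 8          # 800 MB total if every rank kept a full copy
world_size = 4                      # e.g. 4 GPUs

def shard(params, rank, world_size):
    """Give each rank a contiguous slice of the parameter list."""
    per_rank = len(params) // world_size
    return params[rank * per_rank:(rank + 1) * per_rank]

for rank in range(world_size):
    local = shard(param_sizes_mb, rank, world_size)
    print(f"rank {rank}: holds {len(local)} params, {sum(local)} MB")
# Each rank stores 200 MB instead of 800 MB; in real FSDP the full
# parameters are all-gathered on demand during forward/backward.
```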


For a complete overview with examples, see https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html


Environment Setup

For running FSDP at OSC, we recommend using a base PyTorch environment or cloning a base PyTorch environment and adding your project’s specific packages to it.


There are 5 main differences between FSDP and single-GPU runs.  See https://github.com/pytorch/examples/tree/main/distributed/FSDP2 for a detailed code example. Note, however, that the example supports multi-GPU but not multinode training.  The changes required for multinode training are covered in the Multinode Changes section below.

1. FSDP Imports

from torch.distributed.fsdp import fully_shard

2. Model Initialization

Apply fully_shard to each layer first, and then to the full model.

model = Transformer()
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

3. FSDP Process Group Initialization

This is called near the start of main().

import os
import torch
import torch.distributed

# set device
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator()
    device = torch.device(f"{device_type}:{rank}")
    torch.accelerator.set_device_index(rank)
else:
    device = torch.device("cpu")

# initialize process group
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, device_id=device)
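For reference, torchrun sets standard environment variables (RANK, LOCAL_RANK, WORLD_SIZE) for every worker process before the code above runs.  A minimal sketch of reading them, with fallbacks so the same script also runs as a plain single process:

```python
import os

# Environment variables set by torchrun for every worker.
# The defaults let the script run outside torchrun as a single process.
rank = int(os.environ.get("RANK", 0))              # global rank across all nodes
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # rank within this node
world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of workers

print(f"global rank {rank} of {world_size}, local rank {local_rank}")
```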

4. Define and set prefetching

def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
    for i, layer in enumerate(model.layers):
        if i >= len(model.layers) - num_to_forward_prefetch:
            break
        layers_to_prefetch = [
            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
        ]
        layer.set_modules_to_forward_prefetch(layers_to_prefetch)

def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
    for i, layer in enumerate(model.layers):
        if i < num_to_backward_prefetch:
            continue
        layers_to_prefetch = [
            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
        ]
        layer.set_modules_to_backward_prefetch(layers_to_prefetch)

#<model setup>
...
if args.explicit_prefetching:
    set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
    set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)
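To see what the helpers above compute, here is a plain-Python simulation (no torch required) of the forward-prefetch schedule, using layer indices instead of real modules: layer i registers the next num_to_forward_prefetch layers, and the final layers register nothing.

```python
# Simulate the schedule built by set_modules_to_forward_prefetch,
# with integer indices standing in for the model's layer modules.
def forward_prefetch_schedule(num_layers, num_to_forward_prefetch):
    schedule = {}
    for i in range(num_layers):
        if i >= num_layers - num_to_forward_prefetch:
            break
        schedule[i] = [i + j for j in range(1, num_to_forward_prefetch + 1)]
    return schedule

print(forward_prefetch_schedule(5, 2))
# {0: [1, 2], 1: [2, 3], 2: [3, 4]}
```

While layer 0 is running its forward pass, the parameters for layers 1 and 2 are already being all-gathered, overlapping communication with computation.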

5. Destroy the process group after training/validation and any post-processing have completed, i.e., outside of the training loop.

torch.distributed.destroy_process_group()


Multinode Changes

There are also 3 changes for multinode training, beyond multi-GPU on a single node.  See https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py

1. Use DistributedSampler to load data (required)

from torch.utils.data.distributed import DistributedSampler 

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
train_kwargs = {'batch_size': train_config.batch_size_training, 'sampler': sampler1}
cuda_kwargs = {'num_workers': train_config.num_workers_dataloader, 'pin_memory': True, 'shuffle': False}
train_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

When shuffle=True, call sampler1.set_epoch(epoch) at the start of each epoch so the shuffling order differs across epochs.
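Conceptually, DistributedSampler pads the index list so its length divides evenly by the number of replicas, then hands each rank every num_replicas-th index.  A toy re-implementation of that partitioning (without shuffling; not torch's actual code):

```python
# Toy version of DistributedSampler's index partitioning (no shuffling).
def partition(dataset_len, rank, num_replicas):
    indices = list(range(dataset_len))
    # Pad by repeating indices from the front so the length divides evenly.
    pad = (-dataset_len) % num_replicas
    indices += indices[:pad]
    # Each rank takes every num_replicas-th index, offset by its rank.
    return indices[rank::num_replicas]

for r in range(3):
    print(f"rank {r}: {partition(10, r, 3)}")
# rank 0: [0, 3, 6, 9]
# rank 1: [1, 4, 7, 0]
# rank 2: [2, 5, 8, 1]
```

Every rank sees the same number of samples per epoch (a requirement for collective operations), which is why two indices are duplicated here.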

2. Global vs local rank tracked separately (optional, but recommended)

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])  # use local rank to load the model onto its device
        self.global_rank = int(os.environ["RANK"])
        ...

3. Only save checkpoints where global_rank == 0 (optional, but recommended)

if self.global_rank == 0:
    ...
    torch.save(model_state_dict, new_model_checkpoint)
    torch.save(optim_state_dict, new_optim_checkpoint)

Note that under FSDP the model and optimizer state dicts are sharded across ranks, so gather a full state dict first (e.g., with torch.distributed.checkpoint.state_dict.get_model_state_dict), or use torch.distributed.checkpoint to save the sharded states from all ranks.


Example Slurm Job Script Using srun and torchrun

#!/bin/bash
#SBATCH --job-name=fsdp-t5-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=96

nodes_array=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO

ml miniconda3/24.1.2-py310
conda activate fsdp

srun torchrun \
--nnodes 2 \
--nproc_per_node 4 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/path/to/examples/distributed/T5-fsdp/fsdp_t5.py

Note that --nproc_per_node should match the number of GPUs per node (here 4, matching --gpus-per-task=4), so that torchrun launches one worker process per GPU.