PyTorch Fully Sharded Data Parallel (FSDP) is used to speed up model training by parallelizing the training data and by sharding model parameters, optimizer states, and gradients across multiple PyTorch processes.
If your model does not fit on a single GPU, you can use FSDP and request more GPUs to reduce the memory footprint on each GPU. The model parameters are sharded across the GPUs, and each training process receives a different subset of the training data. Gradients are synchronized across devices at each step, so every device works with a consistent model.
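How much state is sharded is controlled by FSDP's sharding strategy. For reference, the main options exposed by torch.distributed.fsdp are:

from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD: shard parameters, gradients, and optimizer states (lowest memory)
# SHARD_GRAD_OP: shard only gradients and optimizer states
# NO_SHARD: no sharding; behaves like standard DDP
strategy = ShardingStrategy.FULL_SHARD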
For a complete overview with examples, see https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
To run FSDP at OSC, we recommend using a base PyTorch environment, or cloning a base PyTorch environment and adding your project's specific packages to it.
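For example, on a login node you might clone an existing environment and extend it. This is a sketch: the source environment name ("pytorch") and the extra packages are placeholders; substitute whatever your project actually needs:

ml miniconda3/24.1.2-py310
# clone an existing PyTorch environment (assumed here to be named "pytorch")
conda create --name fsdp --clone pytorch
conda activate fsdp
# add your project's packages
pip install transformers datasets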
There are six main differences between FSDP runs and single-machine runs:
1. Initialize the distributed process group so the ranks can communicate, and bind each process to its GPU:

import os, torch
from torch.distributed import init_process_group

def fsdp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
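Once the process group is initialized, the current rank and world size (used below for the DistributedSampler) can be queried from torch.distributed:

import torch.distributed as dist

rank = dist.get_rank()               # global rank of this process
world_size = dist.get_world_size()   # total number of processes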
2. Wrap the model in the FSDP wrapper, which shards the parameters across ranks:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        ...
        model = FSDP(model,
                     auto_wrap_policy=t5_auto_wrap_policy,
                     mixed_precision=mixed_precision_policy,
                     sharding_strategy=fsdp_config.sharding_strategy,
                     device_id=torch.cuda.current_device(),
                     limit_all_gathers=fsdp_config.limit_all_gathers)
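The t5_auto_wrap_policy and mixed_precision_policy referenced above are defined elsewhere in the T5 example. A minimal sketch of how such policies can be constructed (the bfloat16 choice is an assumption, not a required setting):

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# wrap each T5 transformer block as its own FSDP unit
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

# run compute, gradient reduction, and buffers in bfloat16
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)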
3. Use a DistributedSampler in the DataLoader so each rank receives a different shard of the training data. Shuffling is handled by the sampler, so the DataLoader's own shuffle option must stay False:

from torch.utils.data.distributed import DistributedSampler

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
train_kwargs = {'batch_size': train_config.batch_size_training, 'sampler': sampler1}
cuda_kwargs = {'num_workers': train_config.num_workers_dataloader,
               'pin_memory': True,
               'shuffle': False}
train_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
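When the sampler shuffles, call set_epoch at the start of each epoch so every epoch uses a different ordering (num_epochs here is a placeholder):

for epoch in range(num_epochs):
    sampler1.set_epoch(epoch)  # reseed the shuffle for this epoch
    for batch in train_loader:
        ...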
4. Destroy the process group once training is finished:

import torch.distributed as dist

def cleanup():
    dist.destroy_process_group()
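Putting the setup and teardown together, the training entry point typically looks like the following (the Trainer construction and train() call are schematic, not the example's exact code):

if __name__ == "__main__":
    fsdp_setup()
    trainer = Trainer(trainer_config, model, optimizer, train_dataset)
    trainer.train()
    cleanup()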
5. Optionally apply activation checkpointing to the wrapped model to save memory at the cost of extra compute. This must be applied on every rank, not only rank 0:

...
if fsdp_config.fsdp_activation_checkpointing:
    policies.apply_fsdp_checkpointing(model)
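The example's policies.apply_fsdp_checkpointing helper is essentially a call to PyTorch's activation-checkpointing utility applied to each transformer block; a sketch under that assumption:

import functools
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from transformers.models.t5.modeling_t5 import T5Block

def apply_fsdp_checkpointing(model):
    # recompute activations of each T5Block during backward instead of storing them
    wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        check_fn=lambda submodule: isinstance(submodule, T5Block),
    )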
6. Read the rank environment variables that torchrun sets for each process, so the trainer knows its local rank (GPU index on the node) and global rank:

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
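The local rank identifies which GPU on the node belongs to this process, so batches are moved to that device before the forward pass. A schematic batch step under that assumption (not the example's exact code):

def _run_batch(self, source, targets):
    source = source.to(self.local_rank)
    targets = targets.to(self.local_rank)
    ...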
Finally, the whole job is launched with a Slurm batch script that starts one torchrun launcher per node. An example for a two-node job with four GPUs per node:

#!/bin/bash
#SBATCH --job-name=fsdp-t5-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=96
# determine the head node and its IP address for rendezvous
nodes_array=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO
ml miniconda3/24.1.2-py310
conda activate fsdp
# one torchrun launcher per node; nproc_per_node spawns one rank per GPU,
# matching the --gpus-per-task=4 requested above
srun torchrun \
    --nnodes 2 \
    --nproc_per_node 4 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    /path/to/examples/distributed/T5-fsdp/fsdp_t5.py
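If the script above is saved as fsdp_t5.sbatch (the filename is arbitrary), submit it with:

sbatch fsdp_t5.sbatch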