Overview
Out-of-Memory (OOM) errors during artificial neural network (ANN) training are common and can slow down the process of obtaining desired experimental results. A number of strategies exist to overcome this challenge, including requesting more resources, distributed training, using smaller models and lower-precision datatypes, adjusting hyperparameters, and other techniques. While this is not an exhaustive guide, the following recommendations are meant to reduce GPU memory usage and shorten the time to results. If you require assistance, please contact OSC Support.
Also, consider profiling your GPU memory usage to identify which portions of your training code are using the most memory, allowing you to target your strategies accordingly.
Requesting and Using More Resources
- Request more GPU resources
- Use Fully Sharded Data Parallel (FSDP) in PyTorch to distribute large models across multiple GPU devices
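Below is a minimal FSDP sketch, assuming a single-node launch with torchrun and using a toy model as a stand-in for your own network; the model, learning rate, and launch settings are placeholders, and exact wrapping policies depend on your architecture and job script.
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a launch such as: torchrun --nproc_per_node=<num_gpus> train_fsdp.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model standing in for a real network
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 10)).cuda()
model = FSDP(model)  # parameters, gradients, and optimizer state are sharded across ranks
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()  # dummy loss for illustration
loss.backward()
optimizer.step()
dist.destroy_process_group()
For larger models, FSDP options such as auto-wrap policies and CPU offloading can shard at a finer granularity or move state off the GPU; check the PyTorch FSDP documentation for the version installed in your environment.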
Using Smaller Models and Datatypes
Model choice has the single largest impact on GPU memory usage, so it's important to choose a model that suits your needs but is not unnecessarily large if a bigger model offers no advantage for your use case. See how to estimate GPU memory usage based on model size in billions of parameters. Each parameter's datatype also strongly affects total model size, so consider using lower-precision datatypes where feasible to reduce the memory footprint. Lower-precision calculations can also be faster and use less energy. However, using a smaller model or a lower-precision datatype may negatively impact model fit and overall performance, and individual needs vary in how much flexibility there is in model choice.
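As a rough back-of-the-envelope example (actual usage depends on architecture, optimizer, and activations): fp32 weights take about 4 bytes per parameter and fp16/bf16 about 2, while mixed-precision training with Adam is often estimated at roughly 16 bytes per parameter once gradients, fp32 master weights, and optimizer moments are included. The snippet below just walks through that arithmetic; treat the numbers as starting points, not exact requirements.
# Back-of-the-envelope memory estimate from parameter count (excludes activations).
# The 16 bytes/parameter figure assumes mixed-precision training with Adam
# (fp16 weights and gradients plus fp32 master weights and two optimizer moments).
params = 7e9                   # e.g., a 7-billion-parameter model
weights_fp16_gb = params * 2 / 1e9
weights_fp32_gb = params * 4 / 1e9
training_adam_gb = params * 16 / 1e9
print(f"fp16 weights: ~{weights_fp16_gb:.0f} GB, fp32 weights: ~{weights_fp32_gb:.0f} GB, "
      f"mixed-precision Adam training state: ~{training_adam_gb:.0f} GB (plus activations)")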
- Use smaller models to reduce overall parameter count
- Use lower precision datatypes to reduce bytes per parameter
- Enable mixed precision training - Mixed precision training uses both 32-bit and 16-bit representations at different points during training and can reduce both memory usage and training time (see the sketch after this list).
- Enable bfloat16 training - Pure 16-bit training avoids fp32 floating-point representations entirely during training, for increased speed and reduced memory usage at the cost of potentially worse model fit.
- Quantization - Even lower precision datatypes such as int8 can be used in quantized training, where the datatype is actually cast to a lower bit width.
- Some quantization-related techniques are unlikely to reduce memory usage during training:
- Quantization-Aware Training (QAT) is an alternative to true quantized training, and it simulates lower-precision datatypes alongside the higher precision representations. Therefore, the memory usage during training with QAT can actually increase; its benefits are more oriented toward reduced cost during inference.
- Post-training static and dynamic quantization convert model weights after training is completed; again, the main benefit is reduced inference cost.
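A minimal mixed-precision sketch using torch.autocast with a gradient scaler follows, as referenced in the mixed precision bullet above; the toy model, data, and learning rate are placeholders. On GPUs that support bfloat16, the gradient scaler is typically unnecessary.
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(512, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # loss scaling is needed for fp16; bf16 usually does not need it

x = torch.randn(32, 512, device=device)
y = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
# Forward pass runs selected ops in 16-bit; use dtype=torch.bfloat16 on GPUs that support it
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()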
Setting Hyperparameters
Setting hyperparameters can have a large impact on reducing memory usage during ANN training. Reducing batch size and context length in particular can result in a sizable reduction in memory usage. Another benefit of adjusting hyperparameters is that little to no code changes are required, making it easy to experiment with different values.
- Reduce batch size - training and validation batches can have separate batch size hyperparameters. The memory savings from reducing batch size grow as each training instance gets larger; for example, text encoding with long contexts, or large images, audio, or video clips, sees the largest memory reductions as batch size decreases (see the sketch after this list).
- Reduce context length - this may be called max_seq_len, context_len, or other name if you're using a pre-built model.
- Set Dataloaders to Number of GPUs - to avoid I/O bottlenecks, aim to set your number of dataloaders at least equal to your number of GPUs.
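The sketch below shows where these hyperparameters typically live in a PyTorch DataLoader; the dataset is a toy stand-in for your own, and the specific values are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for your own
dataset = TensorDataset(torch.randn(1000, 512), torch.randint(0, 10, (1000,)))

train_loader = DataLoader(
    dataset,
    batch_size=16,    # reduce this first if you hit OOM during training
    shuffle=True,
    num_workers=4,    # e.g., at least as many worker processes as GPUs to avoid I/O bottlenecks
    pin_memory=True,  # speeds up host-to-GPU copies; does not reduce GPU memory
)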
Other GPU Memory-Reduction Techniques
- Use flash-attention (not available on V100s)
- Gradient Accumulation - increases effective batch size without increasing memory, with minimal code changes (see the sketch after this list)
- Activation Checkpointing - recompute activations, trading off extra computation for lower memory usage
- For validation runs during training, ensure gradient computations are disabled, for example with:
torch.no_grad()
model.eval()
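The sketch below combines gradient accumulation with a gradient-free validation pass, as referenced in the bullets above; the toy model, micro-batch size, and number of accumulation steps are placeholders.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(512, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accumulation_steps = 4  # effective batch size = micro-batch size * accumulation_steps

model.train()
optimizer.zero_grad()
for step in range(8):
    x = torch.randn(8, 512, device=device)  # small micro-batches keep peak memory low
    y = torch.randint(0, 10, (8,), device=device)
    loss = nn.functional.cross_entropy(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Validation: eval mode plus no_grad avoids storing activations for backpropagation
model.eval()
with torch.no_grad():
    val_x = torch.randn(8, 512, device=device)
    val_y = torch.randint(0, 10, (8,), device=device)
    val_loss = nn.functional.cross_entropy(model(val_x), val_y)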
GPU Memory-Reduction During Inference (this section in progress)
- Disable gradient computations - eliminates memory used for unnecessary calculations
- Enable Paged Attention - reduces KV cache memory
- Enable eager mode (vLLM)
- Reduce GPU memory utilization (vLLM)
- Reduce context length - may be called different things in different services
- Reduce batch size - may be called different things in different services
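A sketch of several of these settings using vLLM's offline LLM API is shown below; the model identifier is a placeholder, and argument names may differ between vLLM versions, so check the documentation for the version installed in your environment.
from vllm import LLM, SamplingParams

llm = LLM(
    model="your-model-id",        # placeholder model identifier
    gpu_memory_utilization=0.8,   # cap the fraction of GPU memory vLLM reserves (default is higher)
    max_model_len=4096,           # shorter context length shrinks the KV cache
    enforce_eager=True,           # eager mode skips CUDA graph capture, saving some memory
)

outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
In vLLM, PagedAttention is used by default, so the KV cache is already paged; the settings above mainly limit how much GPU memory is reserved and how large the cache can grow.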