Random training tips and tricks
List of all sorts of things you can use to make your training/finetuning go faster. From one-line tricks to whole architecture changes. Order is random and grouping is somewhat arbitrary. Treat this as a knowledge dump.
- buy bigger GPU.
Memory Optimization Techniques
Optimizer Memory Reduction
- 8-bit Adam - 2*8-bit vs 2*32-bit = ~75% less memory for optimizer states
- Gradient (activation) checkpointing - not storing activations from the forward pass but recomputing them in the backward pass (memory-compute tradeoff)
- "Stateless" optimizers - SGD, Lion. In contrast to e.g. Adam, which stores two extra state tensors per parameter (an additional 2x model size of space), plain SGD stores no extra state and Lion keeps only a single momentum buffer. The catch is that these are not drop-in replacements, and Adam was designed specifically for faster convergence & more stable training.
- Optimizer states offloading to CPU, using torchao:
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
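The activation-recomputation trick from the list above can be sketched with `torch.utils.checkpoint`. This is a minimal illustration, assuming a toy `block` (hypothetical, standing in for one expensive layer); it only checks that the checkpointed pass yields the same gradients as the regular one.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical toy block standing in for one expensive layer of a real model.
block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# Regular forward: intermediate activations are kept for the backward pass.
block(x).sum().backward()
grads_plain = [p.grad.clone() for p in block.parameters()]

# Checkpointed forward: activations inside `block` are not stored and are
# recomputed during backward, trading extra compute for lower memory use.
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grads_ckpt = [p.grad.clone() for p in block.parameters()]
```

In a real model you would typically wrap each transformer block (or each stage of a ConvNet) this way rather than the whole network.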
Parameter-Efficient Methods
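- LoRA (Low-Rank Adaptation) - thin layer over frozen pretrained layers of your model. It uses the trick of decomposing a large update matrix into two smaller low-rank matrices (rank $r$, where $r \ll \min(d, k)$), which gives huge memory savings. Slightly more formal:
$W = W_0 + BA$, where:
  - $W_0 \in \mathbb{R}^{d \times k}$ - pretrained weight tensor (frozen)
  - $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$ (init LoRA tensors) where:
  - $B = 0$ at init
  - $A \sim \mathcal{N}(0, \sigma^2)$ at init
  - $BA = 0$, so at the beginning the LoRA layers are a no-op, but thanks to $A$ being Gaussian there will be symmetry breaking
  - some good links about it: Sebastian Raschka, AI Coffee Break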
- QLoRA - differs from LoRA in that the base model is quantized (usually to 4-bit), while the LoRA layers are kept in 16-bit precision.
- After training it's good to merge the LoRA layers into the base model for lower latency and computational overhead (but once merged you can no longer swap adapters, if you have several of them).
- DoRA - decomposes the weight matrix into a magnitude vector (Euclidean norm) & a directional matrix (angle) and trains them separately.
- GaLore - LoRA but for the gradient matrix. Supposedly it also works for pretraining, in contrast to LoRA, but I haven't tried it.
- Prefix Tuning - in place of the prompt you put a randomly initialized vector (the so-called prefix) and optimize it until you get the correct answer.
  - ✅ tiny number of parameters to tune
  - ❌ takes context length (but alternatively you would put a pre-prompt there)
  - ❌ interpretability - these are not words; you can decode the prefix to the "nearest" words, but the result is often gibberish
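To make the LoRA and merging bullets above concrete, here is a minimal sketch of a LoRA-wrapped linear layer. The class name `LoRALinear`, the `alpha / r` scaling (taken from the LoRA paper's convention), and the init scale of `A` are assumptions for illustration, not the API of any particular library.

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: y = W0 x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        # A is Gaussian, B is zero, so B @ A = 0 and the layer starts as a no-op.
        self.A = nn.Parameter(torch.randn(r, base.in_features) / math.sqrt(r))
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

    @torch.no_grad()
    def merge(self) -> nn.Linear:
        """Fold B @ A into the base weight so inference needs no extra matmul."""
        self.base.weight += self.scale * (self.B @ self.A)
        return self.base

torch.manual_seed(0)
layer = LoRALinear(nn.Linear(32, 32), r=4)
x = torch.randn(2, 32)
out_at_init = layer(x)  # identical to layer.base(x) because B = 0
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
```

With `r=4` on a 32x32 layer only 256 parameters are trainable instead of 1024 (plus bias), and the gap grows quickly with layer size.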
Mixed Precision Strategies
https://sebastianraschka.com/blog/2023/llm-mixed-precision-copy.html
https://medium.com/@jbensnyder/solving-the-limits-of-mixed-precision-training-231019128b4b
On consumer cards, types other than FP32 are faster only thanks to Tensor Cores, which among consumer GPUs are available only on RTX cards. Tensor Cores are used automatically when using mixed precision.
- FP32+FP16/BF16 (speed & memory*) - you store an extra copy of the model in 16-bit to be able to do faster calculations (~2x) (memory-compute trade-off). Gradients are also computed in 16-bit, but stored in FP32. BF16 should be more stable than FP16, as range matters more for NNs than precision. BF16 stands for "Brain Float", from Google Brain, btw.
*memory savings will only be visible if your batch size is sufficiently large for the activation savings to outweigh the additional cost of storing a copy of the weights in 16-bit
- turn on TF32 (Tensor Float) (speed) - not all operations are supported in 16-bit mixed precision, so some have to be done in FP32. Turning on TF32 replaces FP32 in computation (storage is still in FP32) at speeds similar to FP16. TF32 is supported on the Ampere architecture and newer. Fun fact: TF32 is a 19-bit format but has "32" in its name.
# The flag below controls whether to allow TF32 on matmul.
# This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN.
# This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
- FP32+FP8 (speed & memory) - it is possible to do computations in 8-bit with FP32 accumulation, but it's more complicated to set up than AMP, which is natively supported in PyTorch. Use the HF Accelerate library to do this (from what I understand it's a wrapper around 3 other packages [TransformerEngine, MS-AMP, and torchao], but not only that).
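For the FP32+BF16 case above, a minimal sketch of a training loop with PyTorch's native autocast could look like this. The toy model and data are made up for illustration; the loop falls back to CPU autocast when no GPU is present, and it assumes the GPU (if any) supports BF16. With FP16 instead of BF16 you would additionally need a gradient scaler.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

# Toy model and data, purely illustrative.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(128, 32, device=device)
y = torch.randn(128, 1, device=device)

losses = []
for _ in range(5):
    opt.zero_grad(set_to_none=True)
    # Matmuls inside the context run in BF16; the parameters themselves stay in FP32.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        pred = model(x)
    # Compute the loss in FP32.
    loss = nn.functional.mse_loss(pred.float(), y)
    loss.backward()
    opt.step()
    losses.append(loss.item())
```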
Computational Efficiency
Forward/Backward Pass Optimization
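- set gradients to None instead of zeroing them, i.e. optimizer.zero_grad(set_to_none=True), which is the default since PyTorch 2.0 (but this can cause some unexpected behavior - None isn't a tensor, so code that reads .grad before the next backward pass has to handle it)
- non-global Cross-Entropy calculation reduces the memory usage spike at the end (especially beneficial for LLMs). paper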
Hardware Utilization
- avoid moving tensors to another device with .to(device); create tensors directly on the target device instead. If you don't have any synchronization later in the code, you can use .to(device, non_blocking=True)
- use torch.compile() if it works; for me it usually doesn't.
- torch.backends.cudnn.benchmark = True
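A short sketch of the device-placement advice above, assuming a hypothetical mask tensor and batch. The asynchronous copy only applies when a CUDA device is available, and it relies on the host tensor being in pinned memory.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Slower pattern: allocate on the CPU, then copy to the target device.
mask_copied = torch.ones(256, 256).to(device)

# Preferred: allocate directly on the target device.
mask_direct = torch.ones(256, 256, device=device)

# For host data that must be transferred, pinned memory allows an asynchronous copy.
batch = torch.randn(64, 3, 32, 32)
if device.type == "cuda":
    batch = batch.pin_memory().to(device, non_blocking=True)
```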
Training Dynamics
Initialization & Learning Rate
- init weights properly (relying on the default PyTorch init isn't always optimal)
- LR scheduler (OneCycleLR, LR warmup, etc.) and LR search. Similar for BS (batch size warmup etc.)
- some rule relating LR & BS, e.g. linear ($\text{LR} \propto \text{BS}$) or square-root ($\text{LR} \propto \sqrt{\text{BS}}$) scaling (there is also the issue of "critical batch size")
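As an illustration of the scheduler bullet above, here is a minimal sketch of `OneCycleLR` with its default settings (warmup over the first 30% of steps, then cosine annealing). The model, `max_lr`, and step count are arbitrary placeholders.

```python
import torch

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
total_steps = 100
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    lrs.append(sched.get_last_lr()[0])  # LR used for this step
    opt.step()                          # normally preceded by forward + backward
    sched.step()                        # one scheduler step per optimizer step
```

With the defaults, the LR starts at `max_lr / 25`, rises to `max_lr`, and then decays far below the starting value by the end of the run.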
Batch Processing
- max out batch size to fill whole VRAM (batch size finder)
- if you can't fit a large enough batch for stable training (e.g. bsz=1), then use gradient accumulation. In pure PyTorch it boils down to calling opt.step() less frequently, which results in an effectively larger batch size.
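Gradient accumulation as described above can be sketched as follows. The toy data and `accum_steps` value are assumptions; the reference gradient at the top is computed only to show that accumulating scaled micro-batch gradients matches one full-batch pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 10), torch.randn(8, 1)

# Reference: gradient of the full batch of 8 in a single pass.
nn.functional.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

accum_steps = 4                      # 4 micro-batches of 2 -> effective batch size 8
opt.zero_grad(set_to_none=True)
for i, (xb, yb) in enumerate(zip(x.chunk(accum_steps), y.chunk(accum_steps))):
    loss = nn.functional.mse_loss(model(xb), yb)
    (loss / accum_steps).backward()  # scale so the summed gradients equal the full-batch mean
    if (i + 1) % accum_steps == 0:
        accum_grad = model.weight.grad.clone()  # kept only for the comparison
        opt.step()
        opt.zero_grad(set_to_none=True)
```

Note that layers with batch statistics (e.g. BatchNorm) still see the small micro-batches, so accumulation is not fully equivalent to a larger batch for them.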
Stability Techniques
- gradient clipping - prevents exploding gradients by capping their magnitude during backpropagation.
- weight clipping - regularization technique, similar to weight decay in some sense, but it isn't part of the loss calculation.
- start from a pretrained model (transfer learning) and swap the last layer.
- freeze the whole model except the last layer and after a few epochs gradually unfreeze the rest of the layers (ULMFiT, gradual unfreezing).
- constrain the latent space of your model to align with some general pretrained model; this stabilizes training and reduces overfitting (can be done with feature matching losses).
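Gradient clipping from the list above is a one-liner in PyTorch. The sketch below uses deliberately huge targets (an artificial setup) to produce a large gradient, then clips the global L2 norm before the optimizer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Deliberately large targets to produce a large gradient.
x, y = torch.randn(16, 10), 1000 * torch.randn(16, 1)
nn.functional.mse_loss(model(x), y).backward()

max_norm = 1.0
# Rescales all gradients in place so their global L2 norm is at most max_norm;
# returns the norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
opt.step()
```

Logging `norm_before` during training is a cheap way to spot instabilities before they show up in the loss.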
Model-Specific Optimizations
LLM Techniques
- SuperBPE - groups frequent word sequences into single tokens, improving efficiency and performance. Common word combinations get treated as one unit by the tokenizer, which reduces the number of easy-to-predict sequences. This creates a more balanced prediction difficulty across tokens, allowing the model to distribute computational effort more effectively. Author explanation
Diffusion Model Techniques
- latent diffusion - diffuse in latent space, not pixel space: VAE encoder → latent (diffuse here) → VAE decoder. The SD1.5 VAE may be big, but you can use TAESD (5MB for enc/dec each and minimal computational overhead)
- if you want to train on ImageNet: https://huggingface.co/datasets/fal/cosmos-imagenet (compressed to 2.45GB)
- Min-SNR - method of adding a weighting to the loss based on the SNR (signal-to-noise ratio) of the timestep. It prevents conflicting gradients from different denoising phases (beginning, mid and final refinements)
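The Min-SNR weighting above can be sketched for the noise-prediction objective as a per-timestep weight of min(SNR, γ) / SNR, with SNR(t) = ᾱ_t / (1 − ᾱ_t). The helper name `min_snr_weights`, the linear beta schedule, and γ = 5 are assumptions for illustration.

```python
import torch

def min_snr_weights(alphas_cumprod: torch.Tensor, t: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    """Per-sample loss weights min(SNR, gamma) / SNR for noise-prediction training."""
    alpha_bar = alphas_cumprod[t]
    snr = alpha_bar / (1.0 - alpha_bar)
    return torch.clamp(snr, max=gamma) / snr

# Linear beta schedule with 1000 steps (a common DDPM setup, assumed here).
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

t = torch.tensor([0, 500, 999])
weights = min_snr_weights(alphas_cumprod, t)
# Usage with a hypothetical per-sample MSE: loss = (weights * per_sample_mse).mean()
```

Low-noise timesteps (very high SNR) get strongly down-weighted, while noisy timesteps keep a weight of 1.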
Implementation Tricks
PyTorch Data Handling
- use torch.as_tensor() rather than torch.tensor(). torch.tensor() always copies data. If you have a NumPy array that you want to convert, use torch.as_tensor() or torch.from_numpy() to avoid copying the data.
- try the "channels last" format for tensors and model (NCHW => NHWC), sometimes it's faster. link
- slow dataloader optimizations: Simo tweet, PyTorch forum
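The copy-vs-share behavior and the channels-last layout from the list above can be demonstrated in a few lines. Array and tensor shapes are arbitrary; whether channels-last is actually faster depends on hardware and model, so this only shows how to switch the layout.

```python
import numpy as np
import torch

arr = np.zeros((2, 3), dtype=np.float32)

copied = torch.tensor(arr)      # always allocates new memory
shared = torch.as_tensor(arr)   # reuses the NumPy buffer (CPU tensor, same dtype)

arr[0, 0] = 1.0                 # visible through `shared`, not through `copied`

# Channels-last: the logical shape stays NCHW, but memory is laid out as NHWC.
images = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
conv = torch.nn.Conv2d(3, 16, kernel_size=3).to(memory_format=torch.channels_last)
out = conv(images)
```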
Other
- normalization layers (stabilize and speed up training - you can use a higher LR)
- turn off bias in the layer before BatchNorm (BN already applies a shift)
- MoE - model architecture which selectively activates only part of the model (memory-computation tradeoff). MoE reaches the same loss faster than dense models under the same compute budget.
- MoD (Mixture of Depths) - learned skip connection for each transformer block; the model learns not to waste compute on easy tokens.
- Stochastic Depth - each layer in a deep ConvNet has a probability of not being dropped that goes from 1.0 (for the first layer) to 0.5 (for the last layer). Simply dropout of whole layers (prevents vanishing gradients, faster training, better performance)
- checkpoint averaging - a weighted average of previous checkpoints makes the loss landscape smoother and more convex, which speeds up training + reduces overfitting (applies to pretraining & finetuning)
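The Stochastic Depth bullet above can be sketched as a linearly decaying survival probability plus a residual block that is skipped at random during training. The names `survival_probs` and `StochasticDepthBlock` are hypothetical, and a linear layer stands in for the convolutional branch to keep the example short.

```python
import torch
import torch.nn as nn

def survival_probs(num_layers: int, p_last: float = 0.5) -> list[float]:
    """Linearly decay the survival probability from 1.0 (first layer) to p_last (last layer)."""
    if num_layers == 1:
        return [1.0]
    return [1.0 - (i / (num_layers - 1)) * (1.0 - p_last) for i in range(num_layers)]

class StochasticDepthBlock(nn.Module):
    """Residual block whose residual branch is skipped at random during training."""

    def __init__(self, dim: int, p_survive: float):
        super().__init__()
        self.branch = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.p_survive = p_survive

    def forward(self, x):
        if self.training:
            if torch.rand(()).item() < self.p_survive:
                return x + self.branch(x)
            return x  # whole layer dropped for this batch
        # At inference keep every layer, scaled by its survival probability.
        return x + self.p_survive * self.branch(x)

probs = survival_probs(5)
blocks = nn.Sequential(*[StochasticDepthBlock(16, p) for p in probs])
```

Dropped layers cost no forward or backward compute for that batch, which is where the training speedup comes from.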