PyTorch FSDP
Fully Sharded Data Parallel training for large language models
Overview
PyTorch Fully Sharded Data Parallel (FSDP) enables training of large models by sharding model parameters, gradients, and optimizer states across data-parallel workers, so each GPU stores only a fraction of the full model state.
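The sharding idea can be illustrated without a GPU. The pure-Python sketch below is a conceptual model only, not the torch.distributed.fsdp API. The names shard_flat_params, all_gather, and reduce_scatter are hypothetical. It assumes the parameters have already been flattened into one list, which is then padded to a multiple of the number of workers and split into equal contiguous shards, roughly how FSDP lays out its flattened parameters. The collectives are simulated with plain list operations.

```python
# Conceptual sketch of ZeRO-3 / FSDP-style sharding in pure Python.
# Hypothetical helper names; this is NOT the torch.distributed.fsdp API.
from typing import List


def shard_flat_params(flat: List[float], world_size: int) -> List[List[float]]:
    """Pad the flattened values to a multiple of world_size and split them
    into equal contiguous shards, one per data-parallel rank."""
    pad = (-len(flat)) % world_size
    padded = flat + [0.0] * pad
    shard_len = len(padded) // world_size
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Simulate an all-gather: concatenate every rank's shard and drop the
    padding, recovering the full parameters a layer needs to run."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Simulate a reduce-scatter: sum the full-size gradients from all ranks,
    then hand each rank only the slice matching its parameter shard.
    (Data-parallel training usually also divides by world_size to average.)"""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard_flat_params(summed, world_size)


if __name__ == "__main__":
    params = [float(i) for i in range(10)]
    shards = shard_flat_params(params, world_size=4)
    print(shards)                                # each rank holds 3 values
    print(all_gather(shards, len(params)) == params)
```

With 10 parameters and 4 workers, each worker stores a 3-element shard, and the two padding zeros land on the last one. Before a layer runs, its full parameters are gathered; after the backward pass, each worker keeps only the gradient slice for its own shard, which is what allows the optimizer states to be sharded as well.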
Key Features
- ZeRO Stage 3-equivalent sharding (parameters, gradients, and optimizer states)
- Native PyTorch integration via torch.distributed.fsdp (no external dependencies)
- Activation checkpointing support
- Mixed-precision training (BF16/FP16)
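To make "ZeRO Stage 3-equivalent" concrete, the following back-of-the-envelope sketch compares per-GPU model-state memory across ZeRO stages. It assumes the ZeRO paper's accounting for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for FP32 optimizer state (master weights, momentum, variance). Stage 3 corresponds to FSDP's FULL_SHARD strategy. Actual FSDP memory depends on its mixed-precision settings, and activations, all-gather buffers, and fragmentation are ignored here. The function name is hypothetical.

```python
# Per-GPU model-state memory by ZeRO stage, using the ZeRO paper's
# mixed-precision Adam accounting (16 bytes/parameter in total).
# Excludes activations, temporary all-gather buffers, and fragmentation.

BYTES_PARAM = 2   # 16-bit (BF16/FP16) working copy of the parameters
BYTES_GRAD = 2    # 16-bit gradients
BYTES_OPTIM = 12  # FP32 master params + Adam momentum + Adam variance


def model_state_bytes_per_gpu(num_params: float, world_size: int, stage: int) -> float:
    p, g, o = (num_params * b for b in (BYTES_PARAM, BYTES_GRAD, BYTES_OPTIM))
    if stage == 0:  # plain data parallel: everything replicated on every GPU
        return p + g + o
    if stage == 1:  # shard optimizer states only
        return p + g + o / world_size
    if stage == 2:  # shard optimizer states and gradients
        return p + (g + o) / world_size
    if stage == 3:  # shard parameters too (FSDP full sharding)
        return (p + g + o) / world_size
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    for stage in range(4):
        gb = model_state_bytes_per_gpu(7e9, 8, stage) / 1e9
        print(f"7B params, 8 GPUs, stage {stage}: {gb:.2f} GB per GPU")
```

For a 7B-parameter model on 8 GPUs this gives 112 GB, 38.5 GB, 26.25 GB, and 14 GB per GPU for stages 0 through 3. Only full sharding divides every term by the number of workers.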
Quick Start
# Build the container
docker build -t fsdp-training .
# Launch with Slurm
sbatch slurm/run_fsdp.sbatch
Supported Models
| Model | Parameters | Min. GPUs | Tested on (instance type) |
|---|---|---|---|
| Llama 2 7B | 7B | 8 | p5.48xlarge |
| Llama 2 70B | 70B | 64 | p5.48xlarge |
| Llama 3 405B | 405B | 256 | p5.48xlarge |
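A rough sanity check of the table, assuming the same 16 bytes per parameter for mixed-precision Adam under full sharding, and that each p5.48xlarge provides 8 NVIDIA H100 GPUs with 80 GB each. The sketch computes the node count and the per-GPU model state at each listed minimum. It does not by itself explain the minimums: those likely also reflect activation memory, sequence length, and throughput, none of which this estimate includes. The constant and function names are illustrative.

```python
# Rough check of the Supported Models table: per-GPU model state under full
# (ZeRO-3-style) sharding with mixed-precision Adam at 16 bytes/parameter.
# Activations and communication buffers are NOT included.

BYTES_PER_PARAM = 16  # 2 (params) + 2 (grads) + 12 (FP32 Adam states)
GPUS_PER_NODE = 8     # p5.48xlarge: 8 x NVIDIA H100
GPU_MEMORY_GB = 80    # H100 80 GB

TABLE = [  # (model, parameters, min GPUs) as listed in the table above
    ("Llama 2 7B", 7e9, 8),
    ("Llama 2 70B", 70e9, 64),
    ("Llama 3 405B", 405e9, 256),
]


def sharded_state_gb(num_params: float, num_gpus: int) -> float:
    """Model-state memory per GPU, in GB, when fully sharded over num_gpus."""
    return num_params * BYTES_PER_PARAM / num_gpus / 1e9


for model, params, gpus in TABLE:
    gb = sharded_state_gb(params, gpus)
    nodes = gpus // GPUS_PER_NODE
    headroom = GPU_MEMORY_GB - gb
    print(f"{model}: {nodes} node(s), {gb:.1f} GB model state/GPU, "
          f"{headroom:.1f} GB left for activations and buffers")
```

Under these assumptions the three rows map to 1, 8, and 32 nodes, with about 14.0, 17.5, and 25.3 GB of model state per GPU, leaving the remainder of each 80 GB GPU for activations (reduced by activation checkpointing) and communication buffers.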
Launch Scripts
Slurm
See the slurm/ directory for sbatch scripts configured for multi-node training.
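In a multi-node launch, each training process has to discover its global rank, its local GPU index, and the total number of processes before it can join the process group. The sketch below assumes either torchrun-style variables (RANK, LOCAL_RANK, WORLD_SIZE) or Slurm's per-task variables (SLURM_PROCID, SLURM_LOCALID, SLURM_NTASKS). Which of these the scripts in slurm/ actually rely on is not stated here, and DistInfo and dist_info_from_env are hypothetical names.

```python
# Hypothetical helper: derive rank information for one training process
# from torchrun-style variables, falling back to Slurm's per-task variables.
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class DistInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes


def dist_info_from_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    env = os.environ if env is None else env
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every worker.
    if "RANK" in env and "WORLD_SIZE" in env:
        return DistInfo(int(env["RANK"]), int(env.get("LOCAL_RANK", 0)),
                        int(env["WORLD_SIZE"]))
    # srun starting one task per GPU exposes the same facts via SLURM_* vars.
    if "SLURM_PROCID" in env and "SLURM_NTASKS" in env:
        return DistInfo(int(env["SLURM_PROCID"]), int(env.get("SLURM_LOCALID", 0)),
                        int(env["SLURM_NTASKS"]))
    return DistInfo(rank=0, local_rank=0, world_size=1)  # single-process run


if __name__ == "__main__":
    print(dist_info_from_env())
```

In a real training script these values would typically be passed to torch.cuda.set_device(local_rank) and torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) before the model is wrapped in FSDP.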
Kubernetes
See the kubernetes/ directory for Volcano and Kubeflow Training Operator manifests.
HyperPod EKS
See the hyperpod-eks/ directory for launch configuration specific to Amazon SageMaker HyperPod on EKS.