PyTorch FSDP

Fully Sharded Data Parallel training for large language models


Overview

PyTorch Fully Sharded Data Parallel (FSDP) enables training of large models by sharding model parameters, gradients, and optimizer states across data-parallel workers.
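The sharding described above can be illustrated with a toy, pure-Python model of a single FSDP unit. The unit's parameters are flattened into one buffer, padded to a multiple of the world size, and split so each rank owns one contiguous shard. Before compute, the shards are all-gathered back into full parameters (which FSDP frees again afterward). After backward, gradients are reduce-scattered (averaged here) so each rank keeps only its own gradient shard. The helper names (`shard_flat_params`, `all_gather`, `reduce_scatter`) are hypothetical and not part of `torch.distributed`; real FSDP performs these steps on GPU tensors with NCCL collectives, and plain lists stand in for tensors here.

```python
from typing import List


def shard_flat_params(params: List[List[float]], world_size: int) -> List[List[float]]:
    """Flatten a unit's parameters, pad to a multiple of world_size, split evenly."""
    flat = [x for p in params for x in p]
    pad = (-len(flat)) % world_size  # zeros so every rank gets an equal-sized shard
    flat += [0.0] * pad
    shard_len = len(flat) // world_size
    return [flat[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding to recover full params."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


def reduce_scatter(full_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Average each rank's full (padded) gradient, then give rank r only shard r."""
    avg = [sum(vals) / world_size for vals in zip(*full_grads)]
    shard_len = len(avg) // world_size
    return [avg[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Two parameters (5 elements total) sharded over 2 ranks.
shards = shard_flat_params([[1.0, 2.0, 3.0], [4.0, 5.0]], world_size=2)
# shards == [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
full = all_gather(shards, numel=5)
# full == [1.0, 2.0, 3.0, 4.0, 5.0]
grad_shards = reduce_scatter([[1.0] * 6, [3.0] * 6], world_size=2)
# grad_shards == [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
```

Each rank persistently stores only its shard; the full parameters exist only transiently around a unit's forward and backward pass.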

Key Features

  • ZeRO Stage 3-equivalent sharding
  • Native PyTorch integration (no external dependencies)
  • Activation checkpointing support
  • Mixed precision training (BF16/FP16)
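To make "ZeRO Stage 3-equivalent" concrete, here is a rough per-GPU estimate of model-state memory at each ZeRO stage. It assumes the mixed-precision Adam accounting from the ZeRO paper: 2 bytes/param for half-precision weights, 2 for gradients, and 12 for FP32 optimizer state (master weights, momentum, variance), spread over N data-parallel ranks. Activations, temporary all-gather buffers, and allocator overhead are ignored, and FSDP's actual buffer layout differs in detail. In FSDP terms, `ShardingStrategy.FULL_SHARD` corresponds roughly to stage 3 and `ShardingStrategy.SHARD_GRAD_OP` to stage 2. The function name is hypothetical.

```python
def model_state_bytes_per_gpu(num_params: float, world_size: int, stage: int) -> float:
    """Per-GPU bytes for weights + grads + Adam state under ZeRO stage 0-3.

    Assumes 2 B/param half-precision weights, 2 B/param half-precision grads,
    and 12 B/param FP32 optimizer state (master copy, momentum, variance).
    Stage 0 is plain data parallelism with nothing sharded.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    weights = 2 * num_params
    grads = 2 * num_params
    optim = 12 * num_params
    if stage >= 1:  # stage 1: shard optimizer state
        optim /= world_size
    if stage >= 2:  # stage 2: also shard gradients
        grads /= world_size
    if stage >= 3:  # stage 3 (FSDP FULL_SHARD): also shard weights
        weights /= world_size
    return weights + grads + optim


# Illustrative setting: a 7.5B-parameter model on 64 GPUs.
for s in range(4):
    gb = model_state_bytes_per_gpu(7.5e9, 64, s) / 1e9
    print(f"stage {s}: {gb:.1f} GB per GPU")
```

Only stage 3 divides every term by N, which is why it is the configuration that lets model size grow with the number of GPUs.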

Quick Start

# Build the container
docker build -t fsdp-training .

# Launch with Slurm
sbatch slurm/run_fsdp.sbatch

Supported Models

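| Model        | Parameters | Min GPUs | Tested On   |
|--------------|------------|----------|-------------|
| Llama 2 7B   | 7B         | 8        | p5.48xlarge |
| Llama 2 70B  | 70B        | 64       | p5.48xlarge |
| Llama 3 405B | 405B       | 256      | p5.48xlarge |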
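As a sanity check on the Min GPUs column, the sketch below converts each GPU count into p5.48xlarge nodes (8 NVIDIA H100 80 GB GPUs each) and estimates per-GPU model-state memory under full sharding. It assumes 16 bytes/param of mixed-precision Adam state (2 + 2 + 12) and uses the nominal parameter counts from the table; activations, communication buffers, and framework overhead come on top, so the figures are lower bounds, not measurements. The `plan` helper and the `MODELS` dictionary are hypothetical.

```python
import math

GPUS_PER_NODE = 8  # p5.48xlarge: 8x NVIDIA H100 80 GB
BYTES_PER_PARAM = 16  # assumption: bf16 weights (2) + bf16 grads (2) + fp32 Adam state (12)

# Hypothetical copy of the table above: name -> (nominal parameters, min GPUs)
MODELS = {
    "Llama 2 7B": (7e9, 8),
    "Llama 2 70B": (70e9, 64),
    "Llama 3 405B": (405e9, 256),
}


def plan(num_params: float, num_gpus: int) -> tuple:
    """Return (node count, GB of fully sharded model state per GPU)."""
    nodes = math.ceil(num_gpus / GPUS_PER_NODE)
    gb_per_gpu = num_params * BYTES_PER_PARAM / num_gpus / 1e9
    return nodes, gb_per_gpu


for name, (n_params, n_gpus) in MODELS.items():
    nodes, gb = plan(n_params, n_gpus)
    print(f"{name}: {nodes} node(s), ~{gb:.1f} GB model state per GPU")
```

Under these assumptions the fully sharded model state takes roughly 14, 17.5, and 25 GB per GPU for the three rows, leaving the rest of each 80 GB device for activations and buffers.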

Launch Scripts

Slurm

See the slurm/ directory for sbatch scripts configured for multi-node training.

Kubernetes

See kubernetes/ for Volcano/Kubeflow Training Operator manifests.

HyperPod EKS

See hyperpod-eks/ for HyperPod-specific launch configuration.