JAX PaxML

JAX-based distributed training with Google's Pax framework

Overview

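PaxML is a framework for configuring and running training experiments on top of JAX. It supports distributed training by relying on XLA to compile the training step and on automatic parallelism strategies to partition the work across devices.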
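
To make the XLA compilation and parallelism concrete, here is a minimal sketch in plain JAX rather than the PaxML API. The mesh axis name `data`, the toy `loss_fn`, and the array shapes are illustrative assumptions and do not come from this repository. It builds a one-dimensional device mesh, shards the batch dimension across it, and lets `jax.jit` compile the function with XLA.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every visible device (a single device on a CPU-only host).
# The axis name "data" is an illustrative choice.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the leading (batch) dimension across the mesh; replicate the weights.
batch_2d = NamedSharding(mesh, P("data", None))
batch_1d = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())


@jax.jit  # XLA compiles this function; sharded inputs drive the partitioning
def loss_fn(w, x, y):
    """Toy mean-squared-error loss for a linear model (hypothetical example)."""
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


n = len(devices)
x = jax.device_put(jnp.ones((8 * n, 4)), batch_2d)   # inputs, sharded by batch
y = jax.device_put(jnp.zeros((8 * n,)), batch_1d)    # targets, sharded by batch
w = jax.device_put(jnp.ones((4,)), replicated)       # weights, on every device

loss = loss_fn(w, x, y)
print(float(loss))
```

The same script runs unchanged on one device or many, because the batch size is scaled by the device count and the compiler inserts any cross-device communication. PaxML wraps this kind of setup in its own experiment configuration, which is not shown here.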

Quick Start

# Build the container image
docker build -f jax_paxml.Dockerfile -t jax-training .

# Submit the training job with Slurm
sbatch jax.sbatch