Flax


Flax is a neural network library and ecosystem for JAX that is designed for flexibility. Its API for building models is similar to that of PyTorch and Keras, where models are expressed as sequences of modules. The similarities stop there, however: because it is based on JAX, Flax's API for training models is designed around functional programming, with model parameters held outside the model object and passed explicitly to pure functions.
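
For illustration, here is a minimal sketch of this functional pattern (the module name and layer sizes below are arbitrary and are not part of the full example further down this page): the module object holds no weights; parameters are created by init and passed explicitly to apply.

from flax import linen as nn
import jax
import jax.numpy as jnp

# A small, purely illustrative module.
class SmallMLP(nn.Module):
   @nn.compact
   def __call__(self, x):
      x = nn.Dense(features=32)(x)
      x = nn.relu(x)
      return nn.Dense(features=10)(x)

model = SmallMLP()

# Unlike PyTorch, the module holds no weights: parameters are created by
# init() and must be passed explicitly to apply() at every call.
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))
logits = model.apply(variables, jnp.ones((4, 16)))
print(logits.shape)  # (4, 10)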

Installation

Latest available wheels

To see the latest version of Flax that we have built:

 
[name@server ~]$ avail_wheels "flax*"

For more information, see Available wheels.

Installing the Compute Canada wheel

The preferred option is to install it using the Python wheel as follows:

1. Load a Python module with module load python.
2. Create and start a virtual environment.
3. Install Flax in the virtual environment with pip install.
 
(venv) [name@server ~] pip install --no-index flax

High Performance with Flax

Flax with Multiple CPUs or a Single GPU

As a framework based on JAX, Flax derives its high performance from the combination of a functional programming paradigm, automatic differentiation and the Accelerated Linear Algebra (XLA) compiler, which originated in TensorFlow. Concretely, one can use JAX's just-in-time (JIT) compiler to apply XLA to code blocks (often compositions of functions) that are called repeatedly during a training loop, such as the loss computation, backpropagation and gradient updates. Another advantage is that XLA compiles these code blocks into CPU or GPU code transparently, so your Python code stays exactly the same regardless of the device on which it will be executed.
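
As a toy illustration of this idea (the function and array shapes below are made up for the example and do not appear in the job scripts later on this page), a function decorated with jax.jit is traced on its first call and compiled by XLA for whatever device JAX is running on:

import jax
import jax.numpy as jnp

@jax.jit
def mse_loss(w, x, y):
   # Compiled by XLA on the first call, for the default device (CPU or GPU).
   pred = x @ w
   return jnp.mean((pred - y) ** 2)

# The gradient of the compiled function is itself compiled and device-agnostic.
grad_fn = jax.jit(jax.grad(mse_loss))

w = jnp.zeros((16,))
x = jnp.ones((8, 16))
y = jnp.ones((8,))
print(mse_loss(w, x, y), grad_fn(w, x, y).shape)

The same pattern, applied to a whole training step, is what the train_step and update_state functions in the example below use.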

That being said, when training small-scale models we strongly recommend using multiple CPUs instead of a GPU. While training will almost certainly run faster on a GPU (except when the model is very small), if your model and your dataset are not large enough, the speedup relative to CPUs will likely not be very significant and your job will end up using only a small portion of the GPU's compute capabilities. This might not be an issue on your own workstation, but in a shared environment like our HPC clusters it means you are unnecessarily blocking a resource that another user may need to run actual large-scale computations. Furthermore, you would be unnecessarily using up your group's allocation and affecting the priority of your colleagues' jobs.

Simply put, you should not ask for a GPU if your code is not capable of making reasonable use of its compute capacity. The following example illustrates how to submit a Flax job with or without a GPU:


File : flax-example.sh

#!/bin/bash
#SBATCH --nodes 1
#SBATCH --tasks-per-node=1 
#SBATCH --cpus-per-task=1 # change this parameter to 2,4,6,... to see the effect on performance
#SBATCH --gres=gpu:1 # Remove this line to run using CPU only 

#SBATCH --mem=8G      
#SBATCH --time=0:05:00
#SBATCH --output=%N-%j.out
#SBATCH --account=<your account>

module load python # Using Default Python version - Make sure to choose a version that suits your application

module load cuda # Remove this line if not using a GPU

virtualenv --no-download $SLURM_TMPDIR/env
source $SLURM_TMPDIR/env/bin/activate
pip install flax tensorflow torchvision --no-index

echo "starting training..."

python flax-example.py



File : flax-example.py

from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import argparse
import time

parser = argparse.ArgumentParser(description='cifar10 classification models, cpu performance test')
parser.add_argument('--lr', type=float, default=0.1, help='')
parser.add_argument('--batch_size', type=int, default=512, help='')
parser.add_argument('--num_workers', type=int, default=0, help='')


def main():

   args = parser.parse_args()

   seed = jax.random.PRNGKey(42)

   class Net(nn.Module):

      @nn.compact
      def __call__(self,x):
         x = nn.Conv(features=6, kernel_size=(5, 5))(x)
         x = nn.relu(x)
         x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
         x = nn.Conv(features=16, kernel_size=(5, 5))(x)
         x = nn.relu(x)
         x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
         x = x.reshape((x.shape[0], -1))
         x = nn.Dense(features=120)(x)
         x = nn.relu(x)
         x = nn.Dense(features=84)(x)
         x = nn.relu(x)
         x = nn.Dense(features=10)(x)

         return x
   
   # Helper class to cast torch tensors to JAX-friendly NumPy arrays.
   # torchvision's ToTensor produces channels-first (C, H, W) images, while
   # Flax's Conv expects channels-last (H, W, C), so we also transpose here.
   class CastToJnp(object):
      def __call__(self, image):
         return np.array(image, dtype=jnp.float32).transpose(1, 2, 0)

   model = Net()

   params = model.init(seed, jnp.ones([1, 32, 32, 3]))['params'] # initialize weights with a dummy batch of the same shape as the model's inputs (NHWC: batch, height, width, channels)
   optimizer = optax.sgd(args.lr)
   state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

   # Neither Flax nor JAX provides pre-processing or data loading utilities. Here we use PyTorch for the job.

   transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), CastToJnp()])
   dataset_train = CIFAR10(root='./data', train=True, download=False, transform=transform_train) # the dataset must already be present under ./data, since compute nodes do not have internet access
   train_loader = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_jax)

   perf = []

   for batch_idx, (inputs, targets) in enumerate(train_loader):

      # Note: the first iteration is much slower, since it includes the time to JIT-compile train_step and update_state.
      start = time.time()

      grads,loss = train_step(state, inputs, targets)
      state = update_state(state, grads)

      batch_time = time.time() - start
      images_per_sec = args.batch_size/batch_time
      perf.append(images_per_sec)

      print(f"Current Loss: {loss}")

   print(f"Images processed per second: {np.mean(perf)}")


# JIT Compile an entire training step: Forward and Backward in a single function call:
@jax.jit
def train_step(state, inputs, targets):

   def compute_loss(params):
      outputs = state.apply_fn({'params': params}, inputs)
      one_hot = jax.nn.one_hot(targets, 10)
      loss = jnp.mean(optax.softmax_cross_entropy(logits=outputs, labels=one_hot))
      return loss

   backward = jax.value_and_grad(compute_loss) # returns a function that computes both the loss and its gradients with respect to params
   loss, grads = backward(state.params)
   return grads, loss

# JIT Compile weight update:
@jax.jit
def update_state(state, grads):
   return state.apply_gradients(grads=grads)

#Taken from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
def collate_jax(batch):
   if isinstance(batch[0], np.ndarray):
      return np.stack(batch)
   elif isinstance(batch[0], (tuple,list)):
      transposed = zip(*batch)
      return [collate_jax(samples) for samples in transposed]
   else:
      return np.array(batch)

if __name__=='__main__':
   main()
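
Once the job is running, you can confirm which device JAX actually picked up (for example, to make sure the GPU requested in the job script is visible). This optional check can be added near the top of the script above, or run on its own inside the same virtual environment:

import jax

print(jax.default_backend())  # "cpu" or "gpu"
print(jax.devices())          # list of the devices JAX can see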