main()
}}
== Data Parallelism with Multiple GPUs ==
Data Parallelism, in this context, refers to methods that train multiple replicas of a model in parallel, where each replica receives a different chunk of the training data at each iteration. Gradients are then aggregated at the end of an iteration and the parameters of all replicas are updated in either a synchronous or an asynchronous fashion, depending on the method. This approach can provide a significant speed-up by iterating through all examples of a large dataset approximately N times faster, where N is the number of model replicas. An important caveat of this approach is that, in order to get a trained model that is equivalent to the same model trained without Data Parallelism, you must scale either the learning rate or the desired batch size as a function of the number of replicas. See [https://discuss.pytorch.org/t/should-we-split-batch-size-according-to-ngpu-per-node-when-distributeddataparallel/72769/13 this discussion] for more information; a minimal sketch of one common scaling heuristic is shown below. In the examples that follow, each GPU hosts a replica of your model, so the model must be small enough to fit inside the memory of a single GPU.
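The snippet below is a minimal sketch of the linear scaling heuristic: keep the per-replica batch size fixed and multiply the single-replica learning rate by the number of replicas. The file name and the <code>base_lr</code> / <code>per_gpu_batch_size</code> variables are hypothetical and only serve as an illustration; other scaling rules exist, and the right choice depends on your model and optimizer.
{{File
|name=lr-scaling-sketch.py
|lang="python"
|contents=
import jax

# Hypothetical single-GPU hyperparameters used as a reference point.
base_lr = 0.1
per_gpu_batch_size = 256

n_devices = jax.local_device_count()  # number of model replicas (one per GPU)

# Linear scaling heuristic: the per-replica batch size stays constant,
# the global batch size grows with the number of replicas, and the
# learning rate is multiplied by the same factor.
global_batch_size = per_gpu_batch_size * n_devices
scaled_lr = base_lr * n_devices

print(f"{n_devices} replicas: global batch size {global_batch_size}, learning rate {scaled_lr}")
}}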
=== Single Node ===
{{File
|name=flax-example-multigpu.sh
|lang="bash"
|contents=
#!/bin/bash
#SBATCH --nodes 1
#SBATCH --tasks-per-node=1
#SBATCH --cpus-per-task=1 # increase this if using num_workers > 0 to load data in parallel
#SBATCH --gres=gpu:2
#SBATCH --mem=8G
#SBATCH --time=0:05:00
#SBATCH --output=%N-%j.out
#SBATCH --account=<your account>

module load python # using the default Python version - make sure to choose a version that suits your application
module load cuda

virtualenv --no-download $SLURM_TMPDIR/env
source $SLURM_TMPDIR/env/bin/activate
pip install flax tensorflow torchvision --no-index

echo "starting training..."

python flax-example-multigpu.py
}}
{{File
|name=flax-example-multigpu.py
|lang="python"
|contents=
from flax import linen as nn
from flax.training import train_state
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
import optax

import time
import os
import functools

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import argparse

parser = argparse.ArgumentParser(description='cifar10 classification models, multi-GPU performance test')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
parser.add_argument('--batch_size', type=int, default=512, help='global batch size, split across all GPUs')
parser.add_argument('--num_workers', type=int, default=0, help='number of data-loading workers')
def main():

    args = parser.parse_args()

    seed = jax.random.PRNGKey(42)

    n_devices = jax.local_device_count()  # get the number of GPUs available to the job

    class Net(nn.Module):

        @nn.compact
        def __call__(self, x):
            x = nn.Conv(features=6, kernel_size=(5, 5))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2))
            x = nn.Conv(features=16, kernel_size=(5, 5))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(features=120)(x)
            x = nn.relu(x)
            x = nn.Dense(features=84)(x)
            x = nn.relu(x)
            x = nn.Dense(features=10)(x)
            return x

    class CastToJnp(object):

        def __call__(self, image):
            return np.array(image, dtype=jnp.float32)

    model = Net()
    params = model.init(seed, jnp.ones([3, 32, 32]))['params']

    optimizer = optax.sgd(args.lr)

    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
    state = jax_utils.replicate(state)  # broadcast model replicas to all GPUs

    # Neither Flax nor JAX provides pre-processing / data loading code. Here we use PyTorch for the job.
    transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), CastToJnp()])

    dataset_train = CIFAR10(root='./data', train=True, download=False, transform=transform_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_jax)

    perf = []

    for batch_idx, (inputs, targets) in enumerate(train_loader):

        # Split a batch into "n_devices" sub-batches, one per GPU
        inputs = inputs.reshape(n_devices, inputs.shape[0] // n_devices, *inputs.shape[1:])
        targets = targets.reshape(n_devices, targets.shape[0] // n_devices, *targets.shape[1:])

        start = time.time()

        state, loss = train_step(state, inputs, targets)

        batch_time = time.time() - start

        images_per_sec = args.batch_size / batch_time
        perf.append(images_per_sec)

        print(f"Current Loss: {loss}")

    print(f"Images processed per second: {np.mean(perf)}")
# "jax.pmap" parallelizes inputs, function evaluation and outputs over a given axis. | |||
# This axis is first dimension of the inputs by default - it is the number of GPUs in this case. | |||
# jax.map also JIT compiles the function, just like @jax.jit in the single GPU case. | |||
@functools.partial(jax.pmap, axis_name='gpus') | |||
def train_step(state, inputs, targets): | |||
def compute_loss(params): | |||
outputs = state.apply_fn({'params': params}, inputs) | |||
one_hot = jax.nn.one_hot(targets, 10) | |||
loss = jnp.mean(optax.softmax_cross_entropy(logits=outputs, labels=one_hot)) | |||
return loss | |||
backward = jax.value_and_grad(compute_loss) | |||
loss, grads = backward(state.params) | |||
grads = jax.lax.pmean(grads, axis_name='gpus') # compute the average of gradients of all model replicas | |||
loss = jax.lax.pmean(loss, axis_name='gpus') # do the same with the loss | |||
updated_state = state.apply_gradients(grads=grads) # weight update is computed with averaged gradients from all replicas | |||
return updated_state, loss | |||
# Stack the PyTorch DataLoader's samples into NumPy arrays that JAX can consume,
# instead of the default torch tensors.
def collate_jax(batch):

    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [collate_jax(samples) for samples in transposed]
    else:
        return np.array(batch)


if __name__ == '__main__':
    main()
}}
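To make the mechanics of <code>jax.pmap</code> and <code>jax.lax.pmean</code> easier to see in isolation, here is a minimal, self-contained sketch that mimics what <code>train_step</code> does with gradients: split an array along its leading axis into one sub-batch per device, evaluate a function on every device in parallel, and average the per-device results over the <code>gpus</code> axis. The file name and the toy <code>per_device_mean</code> function are hypothetical and are not part of the job above; the sketch also runs on a CPU-only node, where JAX reports a single device.
{{File
|name=pmap-pmean-sketch.py
|lang="python"
|contents=
import functools

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()  # one replica per GPU (or 1 on a CPU-only node)

# A toy global batch of 8 "examples" with 4 features each.
# As in the job above, the global batch size must be divisible by n_devices.
batch = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)

# Split the batch along its leading axis into one sub-batch per device,
# exactly like the reshape applied to "inputs" and "targets" in main().
sub_batches = batch.reshape(n_devices, batch.shape[0] // n_devices, *batch.shape[1:])

# jax.pmap runs the function once per device; jax.lax.pmean averages the
# per-device results over the named axis, like the gradient averaging in train_step.
@functools.partial(jax.pmap, axis_name='gpus')
def per_device_mean(x):
    local = jnp.mean(x)  # mean of this device's sub-batch
    return jax.lax.pmean(local, axis_name='gpus')  # averaged across all devices

result = per_device_mean(sub_batches)
print(result)  # one identical entry per device: the mean over the whole global batch
}}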
=== Multiple Nodes === |