Weights & Biases (wandb)
Revision as of 18:35, 8 February 2021
Weights & Biases (wandb) is a "meta machine learning platform" designed to help AI practitioners and teams build reliable machine learning models for real-world applications by streamlining the machine learning model lifecycle. By using wandb, users can track, compare, explain and reproduce their machine learning experiments.
Using wandb on our clusters
Availability
Because wandb requires an internet connection, its availability on compute nodes depends on the cluster.
Cluster | Available | Notes |
---|---|---|
Béluga | yes ✅ | load the httpproxy module before using wandb: module load httpproxy |
Cedar | yes ✅ | internet access enabled |
Graham | no ❌ | internet access disabled on compute nodes |
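The availability table can also be written as a small lookup, which may be handy when the same job has to run on several clusters. This is only a sketch: the names `WANDB_AVAILABILITY` and `wandb_modules` are hypothetical, the keys are plain ASCII cluster names, and the entries simply mirror the table as of this revision.

```python
# Hypothetical helper that mirrors the availability table on this page.
# Keys are lower-case ASCII cluster names ("beluga", not "béluga").
WANDB_AVAILABILITY = {
    "beluga": {"available": True, "modules": ["httpproxy"]},
    "cedar": {"available": True, "modules": []},
    "graham": {"available": False, "modules": []},
}


def wandb_modules(cluster):
    """Return the extra modules to load before using wandb on `cluster`.

    Raises RuntimeError if the table marks wandb as unavailable on that
    cluster, and KeyError if the cluster is not listed at all.
    """
    entry = WANDB_AVAILABILITY[cluster.lower()]
    if not entry["available"]:
        raise RuntimeError(
            f"wandb is not available on the compute nodes of {cluster}"
        )
    return list(entry["modules"])


if __name__ == "__main__":
    for name in ("beluga", "cedar"):
        print(name, wandb_modules(name))
```

Calling `wandb_modules("graham")` raises an error instead of returning an empty list, so a job fails early rather than hanging on a node without internet access.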
Example
The following example shows how to use wandb to track experiments on Béluga. On Cedar, the same script works without loading the httpproxy module.
File: wandb-test.sh
#!/bin/bash
#SBATCH --cpus-per-task=1
#SBATCH --mem=2G
#SBATCH --time=0-03:00
#SBATCH --output=%N-%j.out
module load python/3.6 httpproxy
virtualenv --no-download $SLURM_TMPDIR/env
source $SLURM_TMPDIR/env/bin/activate
pip install torchvision wandb --no-index
# Save your wandb API key in your .bash_profile, or replace $API_KEY with your actual API key:
wandb login "$API_KEY"
python wandb-test.py
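If `API_KEY` is not set in the job environment, the login step in the script above has nothing to authenticate with. A minimal sketch of a pre-flight check in Python follows; the variable name `API_KEY` comes from the job script, while the function name `require_api_key` and its parameters are hypothetical.

```python
import os


def require_api_key(env=None, name="API_KEY"):
    """Return the API key stored under `name`, or raise if it is missing or blank.

    `env` defaults to the process environment; passing a plain dict makes the
    check easy to try out without touching real credentials.
    """
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(
            f"{name} is not set; export it (for example in ~/.bash_profile) "
            "before submitting the job"
        )
    return key
```

Running such a check at the top of a Python entry point makes the job stop with a readable message instead of failing later, once the training script starts.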
The script wandb-test.py uses the watch() method to log default metrics to Weights & Biases. See their full documentation for more options.
File: wandb-test.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import argparse

import wandb


parser = argparse.ArgumentParser(description='cifar10 classification models, wandb test')
# type=float so that a value given on the command line is not kept as a string
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
parser.add_argument('--batch_size', type=int, default=768, help='batch size')
parser.add_argument('--max_epochs', type=int, default=4, help='number of training epochs')
parser.add_argument('--num_workers', type=int, default=0, help='number of DataLoader workers')


def main():
    args = parser.parse_args()

    print("Starting Wandb...")

    # Start a run; the parsed arguments are stored as the run configuration
    wandb.init(project="wandb-pytorch-test", config=args)

    class Net(nn.Module):

        def __init__(self):
            super(Net, self).__init__()

            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    net = Net()

    transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # download=False: the CIFAR-10 files must already be present under ./data
    dataset_train = CIFAR10(root='./data', train=True, download=False, transform=transform_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

    # Let wandb hook into the model
    wandb.watch(net)

    for epoch in range(args.max_epochs):
        train(epoch, net, criterion, optimizer, train_loader)


def train(epoch, net, criterion, optimizer, train_loader):
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()
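The train() function above never reports the loss value itself. One possible extension, which is an assumption and not part of the original script, is to average the per-batch losses of each epoch and hand the result to a logging callable. `wandb.log` accepts a dictionary of metrics and could be passed in a real run; in the sketch below `print` stands in for it so the snippet runs without wandb. The names `log_epoch_loss` and `log_fn` are hypothetical.

```python
def log_epoch_loss(epoch, batch_losses, log_fn=print):
    """Average the per-batch losses of one epoch and pass them to `log_fn`.

    `batch_losses` is a sequence of floats, for example collected with
    `loss.item()` inside the training loop. `log_fn` is any callable that
    takes a dict of metrics; in a real run this could be `wandb.log`.
    """
    if not batch_losses:
        raise ValueError("batch_losses is empty")
    metrics = {
        "epoch": epoch,
        "train_loss": sum(batch_losses) / len(batch_losses),
    }
    log_fn(metrics)
    return metrics


if __name__ == "__main__":
    # Stand-in values; a real run would collect these from the training loop.
    log_epoch_loss(0, [2.3, 2.1, 1.9])
```

To use it with the script above, train() would append `loss.item()` to a list after each batch and call `log_epoch_loss(epoch, losses, log_fn=wandb.log)` once the loop over `train_loader` finishes.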