{{File
  |lang="python"
  |contents=
import torch
from torch import nn
import torch.nn.functional as F

import pytorch_lightning as pl

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning.strategies import DeepSpeedStrategy

import argparse

parser = argparse.ArgumentParser(description='cifar10 classification models, deepspeed offload to cpu test')
parser.add_argument('--lr', default=0.1, help='')
# Assumed defaults for the remaining hyperparameters; adjust as needed
parser.add_argument('--max_epochs', type=int, default=2, help='')
parser.add_argument('--batch_size', type=int, default=256, help='')
parser.add_argument('--num_workers', type=int, default=0, help='')


def main():
    print("Starting...")

    args = parser.parse_args()

    # Convolutional part of the model
    class ConvPart(nn.Module):

        def __init__(self):
            super(ConvPart, self).__init__()

            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.pool(self.relu(self.conv1(x)))
            x = self.pool(self.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)

            return x

    # Dense feedforward part of the model
    class MLPPart(nn.Module):

        def __init__(self):
            super(MLPPart, self).__init__()

            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)

            return x

    class Net(pl.LightningModule):

        def __init__(self):
            super(Net, self).__init__()

            self.conv_part = ConvPart()
            self.mlp_part = MLPPart()

        # With DeepSpeed ZeRO stage 3, layers instantiated inside this hook
        # are sharded across GPUs as they are created.
        def configure_sharded_model(self):

            self.block = nn.Sequential(self.conv_part, self.mlp_part)

        def forward(self, x):
            x = self.block(x)

            return x

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)

            return loss

        def configure_optimizers(self):
            # DeepSpeedCPUAdam is the optimizer to use when offloading optimizer state to the CPU
            return DeepSpeedCPUAdam(self.parameters())

    net = Net()

    """ Here we initialize a Trainer() explicitly with 1 node and 2 GPUs.
        To make this script more generic, you can use torch.cuda.device_count() to set the number of GPUs
        and you can use int(os.environ.get("SLURM_JOB_NUM_NODES")) to set the number of nodes.
        We also set enable_progress_bar=False to avoid writing a progress bar to the logs,
        which can cause issues due to updating logs too frequently."""

    # ZeRO stage 3 with CPU offloading of optimizer state and parameters,
    # which is what DeepSpeedCPUAdam above is designed for
    trainer = pl.Trainer(accelerator="gpu", devices=2, num_nodes=1,
                         strategy=DeepSpeedStrategy(stage=3, offload_optimizer=True, offload_parameters=True),
                         max_epochs=args.max_epochs, enable_progress_bar=False)

    transform_train = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset_train = CIFAR10(root='./data', train=True, download=False, transform=transform_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers)

    trainer.fit(net, train_loader)

if __name__ == '__main__':
    main()
}}
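As the comment in the script suggests, the Trainer does not have to hard-code the number of GPUs and nodes. The following is a minimal sketch of that idea, assuming the script runs either inside a Slurm job (where Slurm sets SLURM_JOB_NUM_NODES) or on a single node; the fallback to 1 node is our assumption, not part of the script above:

{{File
  |lang="python"
  |contents=
import os
import torch
import pytorch_lightning as pl

# Count the GPUs visible on this node, and read the node count from the
# Slurm environment; outside of a Slurm job we assume a single node.
n_gpus_per_node = torch.cuda.device_count()
n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))

trainer = pl.Trainer(accelerator="gpu", devices=n_gpus_per_node, num_nodes=n_nodes,
                     strategy="deepspeed_stage_3", max_epochs=2)
}}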

