cc_staff
282
edits
mNo edit summary |
mNo edit summary |
||
Line 1,487: | Line 1,487: | ||
|lang="python" | |lang="python" | ||
|contents= | |contents= | ||
import torch | import torch | ||
from torch import nn | from torch import nn | ||
import torch.nn.functional as F | import torch.nn.functional as F | ||
import pytorch_lightning as pl | import pytorch_lightning as pl | ||
import torchvision | import torchvision | ||
import torchvision.transforms as transforms | import torchvision.transforms as transforms | ||
Line 1,503: | Line 1,498: | ||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||
from deepspeed.ops.adam import DeepSpeedCPUAdam | |||
from deepspeed.ops.adam import | from pytorch_lightning.strategies import DeepSpeedStrategy | ||
from pytorch_lightning. | |||
import argparse | import argparse | ||
parser = argparse.ArgumentParser(description='cifar10 classification models, deepspeed offload to cpu test') | parser = argparse.ArgumentParser(description='cifar10 classification models, deepspeed offload to cpu test') | ||
parser.add_argument('--lr', default=0.1, help='') | parser.add_argument('--lr', default=0.1, help='') | ||
Line 1,518: | Line 1,510: | ||
def main(): | def main(): | ||
print("Starting...") | print("Starting...") | ||
args = parser.parse_args() | |||
args = parser.parse_args() | |||
class ConvPart(nn.Module): | |||
class ConvPart(nn.Module): | |||
def __init__(self): | |||
def __init__(self): | |||
super(ConvPart, self).__init__() | super(ConvPart, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | |||
self.conv1 = nn.Conv2d(3, 6, 5) | |||
self.pool = nn.MaxPool2d(2, 2) | self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||
def forward(self, x): | |||
def forward(self, x): | |||
x = self.pool(self.relu(self.conv1(x))) | x = self.pool(self.relu(self.conv1(x))) | ||
x = self.pool(self.relu(self.conv2(x))) | x = self.pool(self.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | x = x.view(-1, 16 * 5 * 5) | ||
return x | |||
return x | |||
# Dense feedforward part of the model | |||
# Dense feedforward part of the model | |||
class MLPPart(nn.Module): | class MLPPart(nn.Module): | ||
def __init__(self): | |||
def __init__(self): | |||
super(MLPPart, self).__init__() | super(MLPPart, self).__init__() | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | |||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | |||
self.fc2 = nn.Linear(120, 84) | self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | self.fc3 = nn.Linear(84, 10) | ||
self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||
def forward(self, x): | |||
def forward(self, x): | |||
x = self.relu(self.fc1(x)) | x = self.relu(self.fc1(x)) | ||
x = self.relu(self.fc2(x)) | x = self.relu(self.fc2(x)) | ||
x = self.fc3(x) | x = self.fc3(x) | ||
return x | |||
return x | |||
class Net(pl.LightningModule): | |||
class Net(pl.LightningModule): | |||
def __init__(self): | |||
def __init__(self): | |||
super(Net, self).__init__() | super(Net, self).__init__() | ||
self.conv_part = ConvPart() | |||
self.conv_part = ConvPart() | |||
self.mlp_part = MLPPart() | self.mlp_part = MLPPart() | ||
def configure_sharded_model(self): | |||
def configure_sharded_model(self): | |||
self.block = nn.Sequential(self.conv_part, self.mlp_part) | |||
self.block = nn.Sequential(self.conv_part, self.mlp_part) | |||
def forward(self, x): | |||
def forward(self, x): | |||
x = self.block(x) | x = self.block(x) | ||
return x | |||
return x | |||
def training_step(self, batch, batch_idx): | |||
def training_step(self, batch, batch_idx): | |||
x, y = batch | x, y = batch | ||
y_hat = self(x) | y_hat = self(x) | ||
Line 1,601: | Line 1,573: | ||
return loss | return loss | ||
def configure_optimizers(self): | |||
def configure_optimizers(self): | return DeepSpeedCPUAdam(self.parameters()) | ||
return | |||
net = Net() | |||
net = Net() | |||
""" Here we initialize a Trainer() explicitly with 1 node and 2 GPU. | |||
""" Here we initialize a Trainer() explicitly with | |||
To make this script more generic, you can use torch.cuda.device_count() to set the number of GPUs | To make this script more generic, you can use torch.cuda.device_count() to set the number of GPUs | ||
and you can use int(os.environ.get("SLURM_JOB_NUM_NODES")) to set the number of nodes. | and you can use int(os.environ.get("SLURM_JOB_NUM_NODES")) to set the number of nodes. | ||
Line 1,615: | Line 1,584: | ||
which can cause issues due to updating logs too frequently.""" | which can cause issues due to updating logs too frequently.""" | ||
trainer = pl.Trainer(accelerator="gpu", devices=2, num_nodes=1, strategy="deepspeed_stage_3", max_epochs = args.max_epochs) | |||
trainer = pl.Trainer( | |||
transform_train = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | |||
dataset_train = CIFAR10(root='./data', train=True, download=False, transform=transform_train) | |||
train_loader = DataLoader(dataset_train, batch_size=args.batch_size, num_workers=args.num_workers) | |||
trainer.fit(net,train_loader) | |||
if __name__=='__main__': | if __name__=='__main__': | ||
main() | main() | ||
}} | }} | ||