From 35f3889338e0ac9a90e2a0ace40978b106bd5d32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Pokr=C3=BDvka?= Date: Thu, 12 Nov 2020 15:19:10 +0000 Subject: [PATCH] =?UTF-8?q?P=C5=99idat=20=E2=80=9Epages/students/2016/luka?= =?UTF-8?q?s=5Fpokryvka/dp2021/mnist/mnist-dist.py=E2=80=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../lukas_pokryvka/dp2021/mnist/mnist-dist.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/mnist/mnist-dist.py diff --git a/pages/students/2016/lukas_pokryvka/dp2021/mnist/mnist-dist.py b/pages/students/2016/lukas_pokryvka/dp2021/mnist/mnist-dist.py new file mode 100644 index 000000000..789bf6c8d --- /dev/null +++ b/pages/students/2016/lukas_pokryvka/dp2021/mnist/mnist-dist.py @@ -0,0 +1,105 @@ +import os +from datetime import datetime +import argparse +import torch.multiprocessing as mp +import torchvision +import torchvision.transforms as transforms +import torch +import torch.nn as nn +import torch.distributed as dist +from apex.parallel import DistributedDataParallel as DDP +from apex import amp + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', + help='number of data loading workers (default: 4)') + parser.add_argument('-g', '--gpus', default=1, type=int, + help='number of gpus per node') + parser.add_argument('-nr', '--nr', default=0, type=int, + help='ranking within the nodes') + parser.add_argument('--epochs', default=2, type=int, metavar='N', + help='number of total epochs to run') + args = parser.parse_args() + args.world_size = args.gpus * args.nodes + os.environ['MASTER_ADDR'] = '147.232.47.114' + os.environ['MASTER_PORT'] = '8888' + mp.spawn(train, nprocs=args.gpus, args=(args,)) + + +class ConvNet(nn.Module): + def __init__(self, num_classes=10): + super(ConvNet, self).__init__() + self.layer1 = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2)) + self.layer2 = nn.Sequential( + nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2)) + self.fc = nn.Linear(7*7*32, num_classes) + + def forward(self, x): + out = self.layer1(x) + out = self.layer2(out) + out = out.reshape(out.size(0), -1) + out = self.fc(out) + return out + + +def train(gpu, args): + rank = args.nr * args.gpus + gpu + dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) + torch.manual_seed(0) + model = ConvNet() + torch.cuda.set_device(gpu) + model.cuda(gpu) + batch_size = 10 + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(gpu) + optimizer = torch.optim.SGD(model.parameters(), 1e-4) + # Wrap the model + model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) + # Data loading code + train_dataset = torchvision.datasets.MNIST(root='./data', + train=True, + transform=transforms.ToTensor(), + download=True) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, + num_replicas=args.world_size, + rank=rank) + train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler) + + start = datetime.now() + total_step = len(train_loader) + for epoch in range(args.epochs): + for i, (images, labels) in enumerate(train_loader): + images = images.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + if (i + 1) % 100 == 0 and gpu == 0: + print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step, + loss.item())) + if gpu == 0: + print("Training complete in: " + str(datetime.now() - start)) + + +if __name__ == '__main__': + torch.multiprocessing.set_start_method('spawn') + main() \ No newline at end of file