diff --git a/mnist/main.py b/mnist/main.py
index 184dc4744f..2477494642 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
 from torch.optim.lr_scheduler import StepLR


 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ def forward(self, x):
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+            if args.dry_run:
+                break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
             if args.dry_run:
                 break

@@ -58,48 +82,128 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)

-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )


-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        default=False,
+        help="use automatic mixed precision",
+    )
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

-    torch.manual_seed(args.seed)
-
     if use_cuda:
         device = torch.device("cuda")
     elif use_mps:
@@ -107,32 +211,43 @@ def main():
     else:
         device = torch.device("cpu")

-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-        ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
+
+    data_dir = args.data_dir
+
+    dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    model = Net().to(device, memory_format=torch.channels_last)
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))

     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
+    if use_cuda and args.use_amp:
+        scaler = torch.GradScaler(device=device.type)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()

@@ -140,5 +255,6 @@ def main():
     torch.save(model.state_dict(), "mnist_cnn.pt")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
+
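A quick smoke test of the patched script could look like the following (a sketch, assuming the patch is applied and the command is run from the mnist/ directory; the flags are the ones defined in the argparse section above, and the mixed-precision path is only taken when CUDA is available, since the scaler is created only under use_cuda):

    python main.py --epochs 1 --dry-run --use-amp --data-dir ../data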