diff --git a/train.py b/train.py index a5eb890..eaa804d 100644 --- a/train.py +++ b/train.py @@ -141,8 +141,10 @@ def main(): print("build vnet") model = vnet.VNet(elu=False, nll=nll) batch_size = args.ngpu*args.batchSz - gpu_ids = range(args.ngpu) - model = nn.parallel.DataParallel(model, device_ids=gpu_ids) + # Use DataParallel only if Cuda is enabled. + if args.cuda: + gpu_ids = range(args.ngpu) + model = nn.parallel.DataParallel(model, device_ids=gpu_ids) if args.resume: if os.path.isfile(args.resume):