diff --git a/notebooks/pytorch_mnist.ipynb b/notebooks/pytorch_mnist.ipynb index 45b382a..4926455 100644 --- a/notebooks/pytorch_mnist.ipynb +++ b/notebooks/pytorch_mnist.ipynb @@ -91,7 +91,7 @@ "class Net(nn.Module):\n", " #This defines the structure of the NN.\n", " def __init__(self):\n", - " NUM_CLASSES = 9\n", + " NUM_CLASSES = 10\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", @@ -156,7 +156,7 @@ " output = model(data)\n", " test_loss += F.nll_loss(output, target, size_average=False).data # sum up batch loss\n", " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", - " correct += pred.eq(target.data.view_as(pre)).long().cpu().sum()\n", + " correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()\n", "\n", " test_loss /= len(test_loader.dataset)\n", " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",