diff --git a/mnist/load.py b/mnist/load.py index 67aea67..6d4f85e 100644 --- a/mnist/load.py +++ b/mnist/load.py @@ -12,19 +12,19 @@ from lib.config import data_dir def mnist(): - fd = open(os.path.join(data_dir,'train-images.idx3-ubyte')) + fd = open(os.path.join(data_dir,'train-images-idx3-ubyte')) loaded = np.fromfile(file=fd,dtype=np.uint8) trX = loaded[16:].reshape((60000,28*28)).astype(float) - fd = open(os.path.join(data_dir,'train-labels.idx1-ubyte')) + fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte')) loaded = np.fromfile(file=fd,dtype=np.uint8) trY = loaded[8:].reshape((60000)) - fd = open(os.path.join(data_dir,'t10k-images.idx3-ubyte')) + fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte')) loaded = np.fromfile(file=fd,dtype=np.uint8) teX = loaded[16:].reshape((10000,28*28)).astype(float) - fd = open(os.path.join(data_dir,'t10k-labels.idx1-ubyte')) + fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte')) loaded = np.fromfile(file=fd,dtype=np.uint8) teY = loaded[8:].reshape((10000))