--- /dev/null
+import struct
+
+import numpy as np
+
+def load(images_file, labels_file):
+ # see http://yann.lecun.com/exdb/mnist/ for the file format
+ with open(images_file, 'rb') as f:
+ # validate magic
+ assert struct.unpack('>I', f.read(4))[0] == 0x803
+ num = struct.unpack('>I', f.read(4))[0]
+ rows = struct.unpack('>I', f.read(4))[0]
+ cols = struct.unpack('>I', f.read(4))[0]
+ images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols)
+ # split them up into an array of flat image vectors, so that first axis corresponds to labels
+ images = images.reshape(num, rows * cols)
+
+ with open(labels_file, 'rb') as f:
+ # validate magic
+ assert struct.unpack('>I', f.read(4))[0] == 0x801
+ num = struct.unpack('>I', f.read(4))[0]
+ labels = np.frombuffer(f.read(), dtype=np.uint8, count=num)
+
+ return images, labels, rows, cols