def load(images_file, labels_file):
    """Load an IDX image file and its IDX label file into NumPy arrays.

    The file layout is described at http://yann.lecun.com/exdb/mnist/:
    big-endian 32-bit header fields followed by raw unsigned-byte data.

    Args:
        images_file: Path of the image file (magic number 0x803), passed
            to ``open``.
        labels_file: Path of the label file (magic number 0x801), passed
            to ``open``.

    Returns:
        A tuple ``(images, labels, rows, cols)`` where ``images`` is a
        ``uint8`` array of shape ``(rows * cols, num_images)`` holding one
        flattened image per column, ``labels`` is a 1-D ``uint8`` array
        with one entry per label record, and ``rows`` / ``cols`` are the
        per-image dimensions read from the image header.

    Raises:
        ValueError: If either file does not start with the expected magic
            number, or (raised by NumPy) if a file holds fewer data bytes
            than its header announces.
        struct.error: If a file is too short to contain its header.
    """
    with open(images_file, 'rb') as f:
        # Header: magic, image count, rows, cols -- four big-endian uint32.
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        # Explicit raise instead of assert: asserts vanish under ``python -O``
        # and a mismatched file would then be parsed as garbage.
        if magic != 0x803:
            raise ValueError(
                'bad magic number {:#x} in image file, expected 0x803'.format(magic)
            )
        num_pixels = num_images * rows * cols
        images = np.frombuffer(f.read(num_pixels), dtype=np.uint8, count=num_pixels)

    # Split the pixel stream into flat images, then transpose so that every
    # image becomes a column vector.
    images = images.reshape(num_images, rows * cols).T

    with open(labels_file, 'rb') as f:
        # Header: magic, label count -- two big-endian uint32.
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != 0x801:
            raise ValueError(
                'bad magic number {:#x} in label file, expected 0x801'.format(magic)
            )
        labels = np.frombuffer(f.read(num_labels), dtype=np.uint8, count=num_labels)

    return images, labels, rows, cols