4c2832cb72bfbf2c336feae5dc57cb5077fac927
[handwriting-recognition.git] / mnist.py
1 import struct
2
3 import numpy as np
4
5 def load(images_file, labels_file):
6     # see http://yann.lecun.com/exdb/mnist/ for the file format
7     with open(images_file, 'rb') as f:
8         # validate magic
9         assert struct.unpack('>I', f.read(4))[0] == 0x803
10         num = struct.unpack('>I', f.read(4))[0]
11         rows = struct.unpack('>I', f.read(4))[0]
12         cols = struct.unpack('>I', f.read(4))[0]
13         images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols)
14         # split them up into an array of flat image column vectors
15         images = images.reshape(num, rows * cols).T
16
17     with open(labels_file, 'rb') as f:
18         # validate magic
19         assert struct.unpack('>I', f.read(4))[0] == 0x801
20         num = struct.unpack('>I', f.read(4))[0]
21         labels = np.frombuffer(f.read(), dtype=np.uint8, count=num)
22
23     return images, labels, rows, cols