import struct

import numpy as np

def _read_idx_header(f, expected_magic, ndims, path):
    """Read a big-endian IDX header and return its ``ndims`` dimension sizes.

    Raises ValueError if the header is truncated or the magic number differs
    from ``expected_magic``.
    """
    size = 4 * (ndims + 1)
    header = f.read(size)
    if len(header) != size:
        raise ValueError(f"{path}: truncated IDX header")
    magic, *dims = struct.unpack(f'>{ndims + 1}I', header)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if magic != expected_magic:
        raise ValueError(
            f"{path}: bad magic number 0x{magic:08x}, "
            f"expected 0x{expected_magic:08x}"
        )
    return dims


def load(images_file, labels_file):
    """Load an MNIST image/label file pair stored in IDX format.

    See http://yann.lecun.com/exdb/mnist/ for the file format.

    Args:
        images_file: path to an IDX image file (magic 0x00000803: unsigned
            bytes, 3 dimensions).
        labels_file: path to an IDX label file (magic 0x00000801: unsigned
            bytes, 1 dimension).

    Returns:
        ``(images, labels, rows, cols)``:
          - ``images``: read-only ``uint8`` array of shape
            ``(rows * cols, num)``. Each column is one flattened image
            (row-major). It is a transposed view, so it is not C-contiguous.
          - ``labels``: read-only ``uint8`` array of shape ``(num,)``.
          - ``rows``, ``cols``: image height and width.

    Raises:
        ValueError: if either file has a truncated header or wrong magic
            number, if the pixel/label payload is shorter than the header
            declares, or if the two files disagree on the item count.
    """
    with open(images_file, 'rb') as f:
        num, rows, cols = _read_idx_header(f, 0x803, 3, images_file)
        # frombuffer raises ValueError if fewer than `count` bytes remain.
        images = np.frombuffer(f.read(), dtype=np.uint8, count=num * rows * cols)
        # Split into an array of flat image column vectors.
        images = images.reshape(num, rows * cols).T

    with open(labels_file, 'rb') as f:
        (num_labels,) = _read_idx_header(f, 0x801, 1, labels_file)
        labels = np.frombuffer(f.read(), dtype=np.uint8, count=num_labels)

    # A mismatched pair would silently mis-align images and labels downstream.
    if num_labels != num:
        raise ValueError(
            f"{images_file} has {num} images but {labels_file} "
            f"has {num_labels} labels"
        )

    return images, labels, rows, cols
