X-Git-Url: https://piware.de/gitweb/?p=handwriting-recognition.git;a=blobdiff_plain;f=mnist.py;h=4c2832cb72bfbf2c336feae5dc57cb5077fac927;hp=1a3870f02b6e04b49877af32c44d3c8ac24804fb;hb=8af4223121b60d5d67b7121d87c5c6fed01b58e7;hpb=6bdea63ab81b6b3dae4649478d29563e2160bc13 diff --git a/mnist.py b/mnist.py index 1a3870f..4c2832c 100644 --- a/mnist.py +++ b/mnist.py @@ -11,8 +11,8 @@ def load(images_file, labels_file): rows = struct.unpack('>I', f.read(4))[0] cols = struct.unpack('>I', f.read(4))[0] images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols) - # split them up into an array of flat image vectors, so that first axis corresponds to labels - images = images.reshape(num, rows * cols) + # split them up into an array of flat image column vectors + images = images.reshape(num, rows * cols).T with open(labels_file, 'rb') as f: # validate magic