# mnist.py -- recovered from a gitweb blobdiff of handwriting-recognition.git
# (https://piware.de/gitweb/?p=handwriting-recognition.git).
import struct

import numpy as np

# Magic numbers that open the two MNIST file kinds handled by load().
_IMAGES_MAGIC = 0x803
_LABELS_MAGIC = 0x801


def _read_u32(f):
    """Read one big-endian unsigned 32-bit integer from binary file *f*."""
    return struct.unpack('>I', f.read(4))[0]


def _expect_magic(f, expected, path):
    """Consume the 4-byte magic number of *f* and verify it equals *expected*.

    The read happens unconditionally (not inside an ``assert``), so the file
    position advances past the magic even when Python runs with ``-O``.
    """
    magic = _read_u32(f)
    if magic != expected:
        raise ValueError(
            '{!r}: bad magic number 0x{:x}, expected 0x{:x}'.format(
                path, magic, expected))


def load(images_file, labels_file):
    """Load an MNIST image file and its matching label file.

    See http://yann.lecun.com/exdb/mnist/ for the file format. Both files
    start with big-endian unsigned 32-bit header fields followed by raw
    unsigned bytes:

    * images: magic ``0x803``, image count, rows, columns, then
      ``count * rows * cols`` pixel bytes;
    * labels: magic ``0x801``, label count, then ``count`` label bytes.

    Args:
        images_file: path of the image file.
        labels_file: path of the label file.

    Returns:
        A tuple ``(images, labels, rows, cols)``. ``images`` is a ``uint8``
        array of shape ``(count, rows * cols)`` (one flat vector per image),
        ``labels`` is a ``uint8`` array with one entry per label, and
        ``rows`` / ``cols`` are the image dimensions read from the header.

    Raises:
        ValueError: if a file does not start with the expected magic number,
            or (raised by NumPy) if it holds fewer data bytes than its
            header announces.
        struct.error: if a file ends before its header is complete.
    """
    with open(images_file, 'rb') as f:
        _expect_magic(f, _IMAGES_MAGIC, images_file)
        num_images = _read_u32(f)
        rows = _read_u32(f)
        cols = _read_u32(f)
        pixels = np.frombuffer(
            f.read(), dtype=np.uint8, count=num_images * rows * cols)
    # Split into an array of flat image vectors, so that the first axis
    # corresponds to the labels.
    images = pixels.reshape(num_images, rows * cols)

    with open(labels_file, 'rb') as f:
        _expect_magic(f, _LABELS_MAGIC, labels_file)
        num_labels = _read_u32(f)
        labels = np.frombuffer(f.read(), dtype=np.uint8, count=num_labels)

    return images, labels, rows, cols