]> piware.de Git - handwriting-recognition.git/blobdiff - mnist.py
Read MNIST db into numpy arrays, display
[handwriting-recognition.git] / mnist.py
diff --git a/mnist.py b/mnist.py
new file mode 100644 (file)
index 0000000..1a3870f
--- /dev/null
+++ b/mnist.py
@@ -0,0 +1,23 @@
+import struct
+
+import numpy as np
+
+def load(images_file, labels_file):
+    # see http://yann.lecun.com/exdb/mnist/ for the file format
+    with open(images_file, 'rb') as f:
+        # validate magic
+        assert struct.unpack('>I', f.read(4))[0] == 0x803
+        num = struct.unpack('>I', f.read(4))[0]
+        rows = struct.unpack('>I', f.read(4))[0]
+        cols = struct.unpack('>I', f.read(4))[0]
+        images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols)
+        # split them up into an array of flat image vectors, so that first axis corresponds to labels
+        images = images.reshape(num, rows * cols)
+
+    with open(labels_file, 'rb') as f:
+        # validate magic
+        assert struct.unpack('>I', f.read(4))[0] == 0x801
+        num = struct.unpack('>I', f.read(4))[0]
+        labels = np.frombuffer(f.read(), dtype=np.uint8, count=num)
+
+    return images, labels, rows, cols