piware.de Git - handwriting-recognition.git/commitdiff
Read MNIST db into numpy arrays, display
author Martin Pitt <martin@piware.de>
Fri, 28 Aug 2020 06:18:29 +0000 (08:18 +0200)
committer Martin Pitt <martin@piware.de>
Sun, 30 Aug 2020 09:20:52 +0000 (11:20 +0200)
README.md
mnist.py [new file with mode: 0644]
read_display_mnist.py [new file with mode: 0755]
screenshots/mnist-visualize-training-data.png [new file with mode: 0644]

diff --git a/README.md b/README.md
index ad1494553d16fddaf4c96faa725c0d27c196c459..fa1c48496f16f54aaba8b1d40f637dc5db10e0b4 100644 (file)
--- a/README.md
+++ b/README.md
@@ -27,9 +27,13 @@ plt.imshow(grad, cmap='gray')
 plt.show()
 
 plt.imshow(np.sin(np.linspace(0,10000,10000)).reshape(100,100) ** 2, cmap='gray')
-# does not work with QT_QPA_PLATFORM=wayland
+# non-blocking does not work with QT_QPA_PLATFORM=wayland
 plt.show(block=False)
 plt.close()
 ```
 
  - Get the handwritten digits training data with `./download-mnist.sh`
+
+ - Read the MNIST database into numpy arrays with `./read_display_mnist.py`. Plot the first ten images and show their labels, to make sure the data makes sense:
+
+   ![visualize training data](screenshots/mnist-visualize-training-data.png)
diff --git a/mnist.py b/mnist.py
new file mode 100644 (file)
index 0000000..1a3870f
--- /dev/null
+++ b/mnist.py
@@ -0,0 +1,23 @@
+import struct
+
+import numpy as np
+
+def load(images_file, labels_file):
+    # see http://yann.lecun.com/exdb/mnist/ for the file format
+    with open(images_file, 'rb') as f:
+        # validate magic
+        assert struct.unpack('>I', f.read(4))[0] == 0x803
+        num = struct.unpack('>I', f.read(4))[0]
+        rows = struct.unpack('>I', f.read(4))[0]
+        cols = struct.unpack('>I', f.read(4))[0]
+        images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols)
+        # split them up into an array of flat image vectors, so that first axis corresponds to labels
+        images = images.reshape(num, rows * cols)
+
+    with open(labels_file, 'rb') as f:
+        # validate magic
+        assert struct.unpack('>I', f.read(4))[0] == 0x801
+        num = struct.unpack('>I', f.read(4))[0]
+        labels = np.frombuffer(f.read(), dtype=np.uint8, count=num)
+
+    return images, labels, rows, cols
diff --git a/read_display_mnist.py b/read_display_mnist.py
new file mode 100755 (executable)
index 0000000..fd85ee3
--- /dev/null
+++ b/read_display_mnist.py
@@ -0,0 +1,14 @@
+#!/usr/bin/python3
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import mnist
+
+train_images, train_labels, rows, cols = mnist.load('train-images-idx3-ubyte', 'train-labels-idx1-ubyte')
+
+# show the first bunch of training data
+for i in range(10):
+    print(f'train image #{i}: label {train_labels[i]}')
+    plt.imshow(train_images[i].reshape(rows, cols), cmap='gray')
+    plt.show()
diff --git a/screenshots/mnist-visualize-training-data.png b/screenshots/mnist-visualize-training-data.png
new file mode 100644 (file)
index 0000000..fd21a11
Binary files /dev/null and b/screenshots/mnist-visualize-training-data.png differ