import struct

import numpy as np


def load(images_file, labels_file):
    """Read an MNIST image file and its label file into numpy arrays.

    See http://yann.lecun.com/exdb/mnist/ for the file format. Both files
    start with big-endian unsigned 32-bit header fields followed by raw
    unsigned bytes.

    Args:
        images_file: path of the image file (magic number 0x803).
        labels_file: path of the label file (magic number 0x801).

    Returns:
        A tuple ``(images, labels, rows, cols)``: ``images`` is a uint8
        array of shape ``(num, rows * cols)`` holding one flat image vector
        per row, ``labels`` is a uint8 array of length ``num``, and
        ``rows``/``cols`` are the image dimensions from the image header.

    Raises:
        ValueError: if a magic number is wrong, if the two files declare
            different item counts, or if a file holds fewer data bytes
            than its header announces.
        struct.error: if a file is too short to contain its header.
    """
    with open(images_file, 'rb') as f:
        magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
        # explicit raise instead of assert: the check must survive `python -O`
        if magic != 0x803:
            raise ValueError(
                f'{images_file}: bad magic {magic:#x}, expected 0x803 (image file)')
        images = np.frombuffer(f.read(), dtype=np.uint8, count=num * rows * cols)
    # split them up into an array of flat image vectors, so that the first
    # axis corresponds to labels
    images = images.reshape(num, rows * cols)

    with open(labels_file, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != 0x801:
            raise ValueError(
                f'{labels_file}: bad magic {magic:#x}, expected 0x801 (label file)')
        labels = np.frombuffer(f.read(), dtype=np.uint8, count=num_labels)

    # images[i] is only meaningful together with labels[i]
    if num_labels != num:
        raise ValueError(
            f'{images_file} has {num} images but {labels_file} has {num_labels} labels')

    return images, labels, rows, cols
#!/usr/bin/python3
"""Print the labels of the first ten MNIST training samples and plot each image."""

import numpy as np
import matplotlib.pyplot as plt

import mnist

train_images, train_labels, rows, cols = mnist.load(
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
)

# sanity check of the parsed data: the printed label should match the plotted digit
for idx in range(10):
    label = train_labels[idx]
    print(f'train image #{idx}: label {label}')
    # images are stored as flat vectors; restore the 2D shape for plotting
    picture = np.reshape(train_images[idx], (rows, cols))
    plt.imshow(picture, cmap='gray')
    plt.show()