]> piware.de Git - handwriting-recognition.git/commitdiff
Rearange image vector
authorMartin Pitt <martin@piware.de>
Sat, 29 Aug 2020 05:45:10 +0000 (07:45 +0200)
committerMartin Pitt <martin@piware.de>
Sun, 30 Aug 2020 09:20:52 +0000 (11:20 +0200)
Put each image into a column instead of a row, which works much better
with the standard formulation of backpropagation algorithms.

mnist.py
read_display_mnist.py

index 1a3870f02b6e04b49877af32c44d3c8ac24804fb..4c2832cb72bfbf2c336feae5dc57cb5077fac927 100644 (file)
--- a/mnist.py
+++ b/mnist.py
@@ -11,8 +11,8 @@ def load(images_file, labels_file):
         rows = struct.unpack('>I', f.read(4))[0]
         cols = struct.unpack('>I', f.read(4))[0]
         images = np.frombuffer(f.read(), dtype=np.uint8, count = num * rows * cols)
-        # split them up into an array of flat image vectors, so that first axis corresponds to labels
-        images = images.reshape(num, rows * cols)
+        # split them up into an array of flat image column vectors
+        images = images.reshape(num, rows * cols).T
 
     with open(labels_file, 'rb') as f:
         # validate magic
index fd85ee3887941fee01f3cd1259ea7b5a6dac55a5..3030a0b50cc7b19142d6c1ca9b570e6c706e96b7 100755 (executable)
@@ -10,5 +10,5 @@ train_images, train_labels, rows, cols = mnist.load('train-images-idx3-ubyte', '
 # show the first bunch of training data
 for i in range(10):
     print(f'train image #{i}: label {train_labels[i]}')
-    plt.imshow(train_images[i].reshape(rows, cols), cmap='gray')
+    plt.imshow(train_images[:, i].reshape(rows, cols), cmap='gray')
     plt.show()