"""Preview the first few MNIST training images with matplotlib."""
import matplotlib.pyplot as plt

# NOTE(review): assumes a project-local ``mnist`` module whose
# ``load(images_path, labels_path)`` returns ``(images, labels, rows, cols)``
# (the unpacking below relies on that 4-tuple) — confirm the import path.
import mnist

# How many leading training samples to display by default.
NUM_PREVIEW = 10


def main(num_preview=NUM_PREVIEW):
    """Load the MNIST training set and show the first ``num_preview`` images.

    For each previewed sample, its index and label are printed to stdout and
    the image is shown as a ``rows x cols`` grayscale picture.

    Args:
        num_preview: Maximum number of samples to display; clamped to the
            number of samples actually loaded.
    """
    train_images, train_labels, rows, cols = mnist.load(
        'train-images-idx3-ubyte', 'train-labels-idx1-ubyte')

    # Samples are indexed by column (``train_images[:, i]``), so the sample
    # count is the size of the second axis. Clamp to it so we never index
    # past the end. Assumes ``train_images`` is a 2-D NumPy-like array — confirm.
    count = min(num_preview, train_images.shape[1])

    # show the first bunch of training data
    for i in range(count):
        print(f'train image #{i}: label {train_labels[i]}')
        plt.imshow(train_images[:, i].reshape(rows, cols), cmap='gray')
        # Without show() a non-interactive script never renders the figure.
        plt.show()


if __name__ == '__main__':
    main()