#!/usr/bin/python3

import numpy as np
import matplotlib.pyplot as plt

import mnist

# Load the training set from the IDX files in the working directory.
# mnist.load also hands back the image height/width used for reshaping below.
train_images, train_labels, rows, cols = mnist.load(
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
)

# Walk through the first ten training samples: print each label, then show
# the matching image. Column i of train_images is reshaped to rows x cols,
# which implies one flattened image per column -- NOTE(review): confirm
# against mnist.load. plt.show() is called once per sample, so each image
# gets its own figure.
for i in range(10):
    label = train_labels[i]
    print(f'train image #{i}: label {label}')
    image = train_images[:, i].reshape(rows, cols)
    plt.imshow(image, cmap='gray')
    plt.show()
