From: Martin Pitt Date: Sat, 29 Aug 2020 12:58:24 +0000 (+0200) Subject: Add backpropagation batching X-Git-Url: https://piware.de/gitweb/?p=handwriting-recognition.git;a=commitdiff_plain;h=8dcd00e9f8bbfc569c9b29ac06d748320d8bf737 Add backpropagation batching This will eventually speed up learning, once it gets run with random subsets of the training data. --- diff --git a/train.py b/train.py index fc9bb4b..b29effa 100755 --- a/train.py +++ b/train.py @@ -102,6 +102,20 @@ def backpropagate(image_batch, label_batch, eta): biases = [b + eta * db for b, db in zip(biases, dbs)] +def train(images, labels, eta, batch_size=100): + '''Do backpropagation for smaller batches + + This greatly speeds up the learning process, at the expense of finding a more erratic path to the local minimum. + ''' + num_images = images.shape[1] + offset = 0 + while offset < num_images: + images_batch = images[:, offset:offset + batch_size] + labels_batch = labels[offset:offset + batch_size] + backpropagate(images_batch, labels_batch, eta) + offset += batch_size + + def test(): """Count percentage of test inputs which are being recognized correctly""" @@ -122,6 +136,6 @@ print(f'correctly recognized images after initialization: {test()}%') for i in range(1): print(f"round #{i} of learning...") - backpropagate(test_images, test_labels, 1) + train(test_images, test_labels, 1) print(f'correctly recognized images: {test()}%')