]> piware.de Git - handwriting-recognition.git/blobdiff - train.py
Add backpropagation batching
[handwriting-recognition.git] / train.py
index fc9bb4b251add6108737281c06e6d92b961d80f4..b29effaa799aeb738f932c8e3260facdf6467a86 100755 (executable)
--- a/train.py
+++ b/train.py
@@ -102,6 +102,20 @@ def backpropagate(image_batch, label_batch, eta):
         biases = [b + eta * db for b, db in zip(biases, dbs)]
 
 
+def train(images, labels, eta, batch_size=100):
+    '''Do backpropagation for smaller batches
+
+    This greatly speeds up the learning process, at the expense of finding a more erratic path to the local minimum.
+    '''
+    num_images = images.shape[1]
+    offset = 0
+    while offset < num_images:
+        images_batch = images[:, offset:offset + batch_size]
+        labels_batch = labels[offset:offset + batch_size]
+        backpropagate(images_batch, labels_batch, eta)
+        offset += batch_size
+
+
 def test():
     """Count percentage of test inputs which are being recognized correctly"""
 
@@ -122,6 +136,6 @@ print(f'correctly recognized images after initialization: {test()}%')
 
 for i in range(1):
     print(f"round #{i} of learning...")
-    backpropagate(test_images, test_labels, 1)
+    train(test_images, test_labels, 1)
 
 print(f'correctly recognized images: {test()}%')