]> piware.de Git - handwriting-recognition.git/blobdiff - train.py
Use linearly falling learning rate
[handwriting-recognition.git] / train.py
index 8510624665ba1ddc66300e00030624b40d7025ac..b294ce1dba98765e2a1559f200a221eb54c96513 100755 (executable)
--- a/train.py
+++ b/train.py
@@ -28,7 +28,7 @@ print(f'correctly recognized images after initialization: {nnet.accuracy(net, te
 train_y = nnet.label_vectors(train_labels, 10)
 for i in range(100):
     for batch in range(0, num_train, nnet_batch):
-        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=1)
+        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=(100-i)/100)
     print(f'cost after training round {i}: {cost}')
 print(f'correctly recognized images after training: {nnet.accuracy(net, test_images, test_labels)}%')