X-Git-Url: https://piware.de/gitweb/?p=handwriting-recognition.git;a=blobdiff_plain;f=train.py;fp=train.py;h=b294ce1dba98765e2a1559f200a221eb54c96513;hp=8510624665ba1ddc66300e00030624b40d7025ac;hb=59f4fd752941f39ddcd36760202a0dc742747106;hpb=0b40285a04bfbf2d73f7a7154eacb4613f08b350 diff --git a/train.py b/train.py index 8510624..b294ce1 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ print(f'correctly recognized images after initialization: {nnet.accuracy(net, te train_y = nnet.label_vectors(train_labels, 10) for i in range(100): for batch in range(0, num_train, nnet_batch): - cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=1) + cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=(100-i)/100) print(f'cost after training round {i}: {cost}') print(f'correctly recognized images after training: {nnet.accuracy(net, test_images, test_labels)}%')