Simplify code
[handwriting-recognition.git] / train.py
index 012e2c7f3527c6917d86fdfa2392b44c12f81d78..8510624665ba1ddc66300e00030624b40d7025ac 100755 (executable)
--- a/train.py
+++ b/train.py
@@ -13,15 +13,13 @@ nnet_batch = 10000
 # neural network structure: two hidden layers, one output layer
 #                   (input)--> [Linear->Sigmoid] -> [Linear->Sigmoid] -->(output)
 # handle 10,000 vectors at a time
-Z1 = nnet.LinearLayer(input_shape=(rows * cols, nnet_batch), n_out=20)
-A1 = nnet.SigmoidLayer(Z1.Z.shape)
-Z2 = nnet.LinearLayer(input_shape=A1.A.shape, n_out=16)
-A2 = nnet.SigmoidLayer(Z2.Z.shape)
-ZO = nnet.LinearLayer(input_shape=A2.A.shape, n_out=10)
-AO = nnet.SigmoidLayer(ZO.Z.shape)
-net = (Z1, A1, Z2, A2, ZO, AO)
+Z1 = nnet.LinearLayer(input_shape=(rows * cols, nnet_batch), n_out=80)
+A1 = nnet.SigmoidLayer(Z1.shape)
+ZO = nnet.LinearLayer(input_shape=A1.shape, n_out=10)
+AO = nnet.SigmoidLayer(ZO.shape)
+net = (Z1, A1, ZO, AO)
 
-res = nnet.forward(net, train_images[:, 0:10000])
+res = nnet.forward(net, test_images[:, 0:10000])
 print(f'output vector of first image: {res[:, 0]}')
 digit, conf = nnet.classify(res[:, 0])
 print(f'classification of first image: {digit} with confidence {conf}; real label {test_labels[0]}')
@@ -30,6 +28,11 @@ print(f'correctly recognized images after initialization: {nnet.accuracy(net, te
 train_y = nnet.label_vectors(train_labels, 10)
 for i in range(100):
     for batch in range(0, num_train, nnet_batch):
-        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)])
+        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=1)
     print(f'cost after training round {i}: {cost}')
 print(f'correctly recognized images after training: {nnet.accuracy(net, test_images, test_labels)}%')
+
+res = nnet.forward(net, test_images[:, 0:10000])
+print(f'output vector of first image: {res[:, 0]}')
+digit, conf = nnet.classify(res[:, 0])
+print(f'classification of first image: {digit} with confidence {conf}; real label {test_labels[0]}')