X-Git-Url: https://piware.de/gitweb/?p=handwriting-recognition.git;a=blobdiff_plain;f=train.py;fp=train.py;h=5ff76669709f1d2a9549dd25b1f16a7b4c97f632;hp=012e2c7f3527c6917d86fdfa2392b44c12f81d78;hb=72f90468f63b736bca28d0fc5ebf3f7d1989de4f;hpb=1de3cdb5ecba32a8a3b0a02bbf71e883383a689d

diff --git a/train.py b/train.py
index 012e2c7..5ff7666 100755
--- a/train.py
+++ b/train.py
@@ -13,15 +13,13 @@ nnet_batch = 10000
 # neural network structure: two hidden layers, one output layer
 # (input)--> [Linear->Sigmoid] -> [Linear->Sigmoid] -->(output)
 # handle 10,000 vectors at a time
-Z1 = nnet.LinearLayer(input_shape=(rows * cols, nnet_batch), n_out=20)
+Z1 = nnet.LinearLayer(input_shape=(rows * cols, nnet_batch), n_out=80)
 A1 = nnet.SigmoidLayer(Z1.Z.shape)
-Z2 = nnet.LinearLayer(input_shape=A1.A.shape, n_out=16)
-A2 = nnet.SigmoidLayer(Z2.Z.shape)
-ZO = nnet.LinearLayer(input_shape=A2.A.shape, n_out=10)
+ZO = nnet.LinearLayer(input_shape=A1.A.shape, n_out=10)
 AO = nnet.SigmoidLayer(ZO.Z.shape)
-net = (Z1, A1, Z2, A2, ZO, AO)
+net = (Z1, A1, ZO, AO)
 
-res = nnet.forward(net, train_images[:, 0:10000])
+res = nnet.forward(net, test_images[:, 0:10000])
 print(f'output vector of first image: {res[:, 0]}')
 digit, conf = nnet.classify(res[:, 0])
 print(f'classification of first image: {digit} with confidence {conf}; real label {test_labels[0]}')
@@ -30,6 +28,11 @@ print(f'correctly recognized images after initialization: {nnet.accuracy(net, te
 train_y = nnet.label_vectors(train_labels, 10)
 for i in range(100):
     for batch in range(0, num_train, nnet_batch):
-        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)])
+        cost = nnet.train(net, train_images[:, batch:(batch + nnet_batch)], train_y[:, batch:(batch + nnet_batch)], learning_rate=1)
     print(f'cost after training round {i}: {cost}')
 print(f'correctly recognized images after training: {nnet.accuracy(net, test_images, test_labels)}%')
+
+res = nnet.forward(net, test_images[:, 0:10000])
+print(f'output vector of first image: {res[:, 0]}') +digit, conf = nnet.classify(res[:, 0]) +print(f'classification of first image: {digit} with confidence {conf}; real label {test_labels[0]}')