X-Git-Url: https://piware.de/gitweb/?p=handwriting-recognition.git;a=blobdiff_plain;f=train.py;h=8510624665ba1ddc66300e00030624b40d7025ac;hp=5ff76669709f1d2a9549dd25b1f16a7b4c97f632;hb=0b40285a04bfbf2d73f7a7154eacb4613f08b350;hpb=72f90468f63b736bca28d0fc5ebf3f7d1989de4f diff --git a/train.py b/train.py index 5ff7666..8510624 100755 --- a/train.py +++ b/train.py @@ -14,9 +14,9 @@ nnet_batch = 10000 # (input)--> [Linear->Sigmoid] -> [Linear->Sigmoid] -->(output) # handle 10,000 vectors at a time Z1 = nnet.LinearLayer(input_shape=(rows * cols, nnet_batch), n_out=80) -A1 = nnet.SigmoidLayer(Z1.Z.shape) -ZO = nnet.LinearLayer(input_shape=A1.A.shape, n_out=10) -AO = nnet.SigmoidLayer(ZO.Z.shape) +A1 = nnet.SigmoidLayer(Z1.shape) +ZO = nnet.LinearLayer(input_shape=A1.shape, n_out=10) +AO = nnet.SigmoidLayer(ZO.shape) net = (Z1, A1, ZO, AO) res = nnet.forward(net, test_images[:, 0:10000])