diff --git a/mmodel/theano/THEANO.py b/mmodel/theano/THEANO.py index f2af954..5231128 100644 --- a/mmodel/theano/THEANO.py +++ b/mmodel/theano/THEANO.py @@ -48,7 +48,7 @@ class ModelTHEANO(ModelBase): X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) train_set_x, train_set_y = shared_dataset((X_train, Y_train)) - valid_set_x, valid_set_y = train_set_x[:1000], train_set_y[:1000] + valid_set_x, valid_set_y = shared_dataset((X_train[:1000], Y_train[:1000])) test_set_x, test_set_y = shared_dataset((X_test, Y_test)) # compute number of minibatches for training, validation and testing @@ -89,8 +89,8 @@ class ModelTHEANO(ModelBase): # Construct the second convolutional pooling layer # filtering reduces the image size to (148-5+1, 148-5+1) = (144, 144) - # maxpooling reduces this further to (144/4, 144/4) = (38, 38) - # 4D output tensor is thus of shape (batch_size, nkerns[1], 38, 38) + # maxpooling reduces this further to (144/4, 144/4) = (36, 36) + # 4D output tensor is thus of shape (batch_size, nkerns[1], 36, 36) layer1 = ConvPoolLayer( rng, input=layer0.output, @@ -101,15 +101,15 @@ class ModelTHEANO(ModelBase): # the HiddenLayer being fully-connected, it operates on 2D matrices of # shape (batch_size, num_pixels) (i.e matrix of rasterized images). - # This will generate a matrix of shape (batch_size, nkerns[1] * 4 * 4), - # or (500, 50 * 4 * 4) = (500, 800) with the default values. + # This will generate a matrix of shape (batch_size, nkerns[1] * 36 * 36), + # or (500, 50 * 36 * 36) = (500, 800) with the default values. layer2_input = layer1.output.flatten(2) # construct a fully-connected sigmoidal layer layer2 = HiddenLayer( rng, input=layer2_input, - n_in=nkerns[1] * 38 * 38, + n_in=nkerns[1] * 36 * 36, n_out=500, activation=T.tanh ) @@ -156,8 +156,8 @@ class ModelTHEANO(ModelBase): ] """ Total Parameters: - >>> 20 * 64 + 1000 * 25 + 50 * 38 * 38 * 500 + 500 * 2 - 36127280 + >>> 20 * 64 + 1000 * 25 + 50 * 36 * 36 * 500 + 500 * 2 + 32427280 """ train_model = theano.function( [index], -- libgit2 0.21.2