I'm training some MLP models on the CIFAR-10 dataset, and I want to try data augmentation as in the code block below.
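import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models
learning_rate = 0.01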
batch_size = 32
epoch = 50
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# convert from integers to floats
train_images = train_images.astype('float32')
test_images = test_images.astype('float32')
# normalize to range 0-1
train_images = train_images / 255.0
test_images = test_images / 255.0
train_labels = keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
augment = keras.preprocessing.image.ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
it_train = augment.flow(train_images, train_labels, batch_size=batch_size)
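The ImageDataGenerator options above (horizontal_flip=True and the width/height shifts) act on the spatial axes of each image. That is why augment.flow takes rank-4 arrays shaped (samples, 32, 32, 3) and yields batches in that same layout. As an illustration only, here is a numpy sketch of the flip alone, assuming a 50% flip chance per image. random_horizontal_flip is a hypothetical helper and is not part of Keras:

```python
import numpy as np


def random_horizontal_flip(batch, rng):
    """Mirror each image left-to-right with probability 0.5 (hypothetical helper).

    batch: array of shape (batch, height, width, channels), the layout
    that ImageDataGenerator.flow works with.
    """
    flip = rng.random(len(batch)) < 0.5     # one coin toss per image
    out = batch.copy()
    out[flip] = out[flip][:, :, ::-1, :]    # reverse the width axis of the chosen images
    return out, flip


rng = np.random.default_rng(0)
images = rng.random((4, 32, 32, 3)).astype("float32")  # random stand-in images
flipped, mask = random_horizontal_flip(images, rng)
print(flipped.shape)  # (4, 32, 32, 3): augmentation keeps the 4-D image layout
```

Because augmentation needs the images in this 4-D form, the batches that reach the model are 4-D as well.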
Here is the model I'm using:
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
model = models.Sequential()
model.add(layers.Dense(units=1000, activation=activation, input_dim=3072))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.2))
model.add(layers.Dense(units=300, activation=activation))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.2))
model.add(layers.Dense(units=100, activation=activation))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.2))
model.add(layers.Dense(units=10, activation='softmax'))
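The first Dense layer is declared with input_dim=3072, so it expects each sample as a flat vector of 32 × 32 × 3 = 3072 values, while CIFAR-10 images (and the augmented batches) are rank-4 arrays. A numpy sketch of the reshape that bridges the two, which is essentially what a Keras Flatten layer does to each batch (the batch here is all-zero stand-in data):

```python
import numpy as np

height, width, channels = 32, 32, 3
# Stand-in for one augmented batch of 32 CIFAR-10 images.
batch = np.zeros((32, height, width, channels), dtype="float32")

# Keep the batch axis and merge height, width and channels into one feature axis.
flat = batch.reshape(batch.shape[0], -1)

print(height * width * channels)       # 3072
print(batch.shape, "->", flat.shape)   # (32, 32, 32, 3) -> (32, 3072)
```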
And this is the line I use to train the model:
history = model.fit(it_train, steps_per_epoch=len(train_images), epochs=epoch, validation_data=(test_images, test_labels))
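Separately from the shape error, note that steps_per_epoch counts batches, not images. With 50,000 CIFAR-10 training images and batch_size = 32, one pass over the data is ceil(50000 / 32) = 1563 batches, so steps_per_epoch=len(train_images) makes each "epoch" roughly 32 times longer than a real pass. The iterator returned by augment.flow reports this batch count via len(it_train). The arithmetic, as a small sketch:

```python
import math

num_train_images = 50_000  # size of the CIFAR-10 training split
batch_size = 32

# Number of batches needed to see every training image once.
steps_per_epoch = math.ceil(num_train_images / batch_size)
print(steps_per_epoch)  # 1563
```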
However, I get the following error. (The CIFAR-10 images are 32x32x3, and there are 10 labels.)
ValueError: Input 0 of layer batch_normalization is incompatible with the layer: expected ndim=2, found ndim=4. Full shape received: (None, None, None, 1000)
What can I do to get rid of this error?
Answer:
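CIFAR-10 images have the shape (32, 32, 3), but your model's input doesn't take that shape: the first Dense layer is declared with input_dim=3072, so it expects each sample as a flat vector of 32 × 32 × 3 = 3072 values. For the model input, you can try the following: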
model = keras.Sequential()
# Add a Flatten layer before the first Dense layer; it flattens each
# incoming (32, 32, 3) image into a vector of 3072 values.
model.add(keras.layers.Flatten(input_shape=(32, 32, 3)))
model.add(keras.layers.Dense(units=1000, activation=activation))
model.add(keras.layers.BatchNormalization())
...
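A fuller sketch of the fixed model, with the remaining layers taken from the question's model and a Flatten step in front. Assumptions: activation="relu" (the question never defines activation), and categorical_crossentropy as the loss because the labels were one-hot encoded with to_categorical; the question does not show its compile call. This sketch declares the input with keras.Input instead of Flatten(input_shape=...), which newer Keras versions prefer; both describe the same (32, 32, 3) input.

```python
import numpy as np
from tensorflow import keras


def build_mlp(activation="relu", learning_rate=0.01):
    """MLP for 32x32x3 CIFAR-10 images; activation="relu" is an assumption."""
    model = keras.Sequential([
        keras.Input(shape=(32, 32, 3)),
        # Flatten each (32, 32, 3) image into a 3072-value vector for the Dense layers.
        keras.layers.Flatten(),
        keras.layers.Dense(1000, activation=activation),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(300, activation=activation),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(100, activation=activation),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        loss="categorical_crossentropy",  # assumed: labels are one-hot from to_categorical
        metrics=["accuracy"],
    )
    return model


model = build_mlp()
# Forward pass on two all-zero stand-in images to check the output shape.
probs = model(np.zeros((2, 32, 32, 3), dtype="float32"))
print(probs.shape)  # (2, 10)
```

With this model, training can follow the question's call, e.g. model.fit(it_train, steps_per_epoch=len(it_train), epochs=epoch, validation_data=(test_images, test_labels)). Flattening inside the model, rather than reshaping the arrays beforehand, also keeps augmentation working, since augment.flow expects rank-4 image arrays.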