在 TensorFlow/Keras 中与批处理轴一起展平

Flatten alongside with batch axis in TensorFlow / Keras

在顺序模型中,我试图从 (None, 300) 的图层输出形状变为 (1,1,None*300) 之类的形状以应用 AveragePooling 图层。事实上,我想展平一切(甚至批次轴),而 FlattenReshape 层总是跳过批次轴。有什么想法吗?

您可以像这样从后端使用 Lambda 层和 K.reshape

from keras import backend as K

out = Lambda(lambda x: K.reshape(x, (1, 1, -1)))(inp)