将 tf.squeeze 替换为 tf.reshape

Replace tf.squeeze using as tf.reshape

我需要使用 tensorflow 训练移动网络。不支持 tf.squeeze 层。我可以用 tf.reshape 替换它吗?

是否操作:

tf.squeeze(net, [1, 2], name='squeeze')

等同于:

tf.reshape(net, [50,1000], name='reshape')

其中网的形状为 [50,1,1,1000]。

为什么说tf.squeeze不支持?为了从张量中移除一维轴,tf.squeeze 是正确的操作。但是你也可以用 tf.reshape 来完成你想要的工作,尽管我建议你使用 tf.squeeze.

tf 2.0 中,您可以轻松检查这些操作是否相同。唯一的区别是您可以使用 dim == 1 删除所有轴而不指定它们。所以在最后一行你可以使用 tf.squeeze(x_resh) 而不是 tf.squeeze(x_resh, [1, 2]).

size = [2, 3]
tf.random.set_seed(42)
x = tf.random.normal(size)
x
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>

x_resh = tf.reshape(x, [2, 1, 1, 3])
x_resh
<tf.Tensor: shape=(2, 1, 1, 3), dtype=float32, numpy=
array([[[[ 0.3274685, -0.8426258,  0.3194337]]],
       [[[-1.4075519, -2.3880599, -1.0392479]]]], dtype=float32)>

tf.reshape(x_resh, [2, 3])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>

tf.squeeze(x_resh, [1, 2])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>