将 tf.squeeze 替换为 tf.reshape
Replace tf.squeeze using as tf.reshape
我需要使用 tensorflow 训练移动网络。不支持 tf.squeeze 层。我可以用 tf.reshape 替换它吗?
是否操作:
tf.squeeze(net, [1, 2], name='squeeze')
等同于:
tf.reshape(net, [50,1000], name='reshape')
其中网的形状为 [50,1,1,1000]。
为什么说tf.squeeze不支持?为了从张量中移除一维轴,tf.squeeze
是正确的操作。但是你也可以用 tf.reshape
来完成你想要的工作,尽管我建议你使用 tf.squeeze
.
在 tf 2.0
中,您可以轻松检查这些操作是否相同。唯一的区别是您可以使用 dim == 1
删除所有轴而不指定它们。所以在最后一行你可以使用 tf.squeeze(x_resh)
而不是 tf.squeeze(x_resh, [1, 2])
.
size = [2, 3]
tf.random.set_seed(42)
x = tf.random.normal(size)
x
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>
x_resh = tf.reshape(x, [2, 1, 1, 3])
x_resh
<tf.Tensor: shape=(2, 1, 1, 3), dtype=float32, numpy=
array([[[[ 0.3274685, -0.8426258, 0.3194337]]],
[[[-1.4075519, -2.3880599, -1.0392479]]]], dtype=float32)>
tf.reshape(x_resh, [2, 3])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>
tf.squeeze(x_resh, [1, 2])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>
我需要使用 tensorflow 训练移动网络。不支持 tf.squeeze 层。我可以用 tf.reshape 替换它吗?
是否操作:
tf.squeeze(net, [1, 2], name='squeeze')
等同于:
tf.reshape(net, [50,1000], name='reshape')
其中网的形状为 [50,1,1,1000]。
为什么说tf.squeeze不支持?为了从张量中移除一维轴,tf.squeeze
是正确的操作。但是你也可以用 tf.reshape
来完成你想要的工作,尽管我建议你使用 tf.squeeze
.
在 tf 2.0
中,您可以轻松检查这些操作是否相同。唯一的区别是您可以使用 dim == 1
删除所有轴而不指定它们。所以在最后一行你可以使用 tf.squeeze(x_resh)
而不是 tf.squeeze(x_resh, [1, 2])
.
size = [2, 3]
tf.random.set_seed(42)
x = tf.random.normal(size)
x
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>
x_resh = tf.reshape(x, [2, 1, 1, 3])
x_resh
<tf.Tensor: shape=(2, 1, 1, 3), dtype=float32, numpy=
array([[[[ 0.3274685, -0.8426258, 0.3194337]]],
[[[-1.4075519, -2.3880599, -1.0392479]]]], dtype=float32)>
tf.reshape(x_resh, [2, 3])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>
tf.squeeze(x_resh, [1, 2])
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258, 0.3194337],
[-1.4075519, -2.3880599, -1.0392479]], dtype=float32)>