使用 K.tile() 复制张量

Copy tensor using K.tile()

我有张量(None, 196),整形后变成(None, 14, 14)。 现在,我想将它复制到通道轴,这样形状应该是 (None, 14, 14, 512)。最后,我想复制到时间轴,所以它变成了(None, 10, 14, 14, 512)。我使用此代码段完成了这些步骤:

def replicate(tensor, input_target):
  batch_size = K.shape(tensor)[0]
  nf, h, w, c = input_target
  x = K.reshape(tensor, [batch_size, 1, h, w, 1])

  # Replicate to channel dimension
  x = K.tile(x, [batch_size, 1, 1, 1, c])

  # Replicate to timesteps dimension
  x = K.tile(x, [batch_size, nf, 1, 1, 1])

  return x

x = ...
x = Lambda(replicate, arguments={'input_target':input_shape})(x)
another_x = Input(shape=input_shape) # shape (10, 14, 14, 512)

x = layers.multiply([x, another_x])
x = ...

我绘制了模型,输出形状就像我想要的那样。但是,问题出现在模型训练中。我将批量大小设置为 2。这是错误消息:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [8,10,14,14,512] vs. [2,10,14,14,512] [[{{node multiply_1/mul}} = Mul[T=DT_FLOAT, _class=["loc:@training/Adam/gradients/multiply_1/mul_grad/Sum"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](Lambda_2/Tile_1, _arg_another_x_0_0/_189)]] [[{{node metrics/top_k_categorical_accuracy/Mean_1/_265}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_6346_metrics/top_k_categorical_accuracy/Mean_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

看起来,K.tile() 将批量大小从 2 增加到 8。当我将批量大小设置为 10 时,它变为 1000。

所以,我的问题是如何达到我想要的结果?使用 tile() 是个好方法吗?或者,我应该使用 repeat_elements() 吗?谢谢!

我正在使用 Tensorflow 1.12.0 和 Keras 2.2.4。

根据经验,尽量避免将批量大小引入 Lambda 层中发生的转换。

当您使用 tile 操作时,您只设置 需要更改的维度(例如您在平铺操作中有 batch_size 值这是错误的)。我还使用 tf.tile 而不是 K.tile (TF 1.12 似乎在 Keras 后端没有磁贴)。

def replicate(tensor, input_target):

  _, nf, h, w, c = input_target
  x = K.reshape(tensor, [-1, 1, h, w, 1])

  # Replicate to channel dimension
  # You can combine below lines to tf.tile(x, [1, nf, 1, 1, c]) as well
  x = tf.tile(x, [1, 1, 1, 1, c])
  # Replicate to timesteps dimension
  x = tf.tile(x, [1, nf, 1, 1, 1])


  return x

简单示例

input_shape= [None, 10, 14, 14, 512]
x = Input(shape=(196,))
x = Lambda(replicate, arguments={'input_target':input_shape})(x)
print(x.shape)

给出

>>> (?, 10, 14, 14, 512)