如何使用 tf.repeat() 复制特定的 column/row/slice?

How to use tf.repeat() to replicate a specific column/row/slice?

线程很好地解释了 tf.repeat() 作为 np.repeat() 的 tensorflow 替代品的使用。我无法弄清楚的一个功能,在 np.repeat() 中,可以通过提供索引来复制特定的 column/row/slice。例如

import numpy as np
x = np.array([[1,2],[3,4]])
np.repeat(x, [1, 2], axis=0)
       # Answer will be -> array([[1, 2],
       #                          [3, 4],
       #                          [3, 4]])

是否有替代 np.repeat() 的此功能的 tensorflow?

您可以使用 tf.repeatrepeats 参数:

import tensorflow as tf

x = tf.constant([[1,2],[3,4]])
x = tf.repeat(x, repeats=[1, 2], axis=0)
print(x)
tf.Tensor(
[[1 2]
 [3 4]
 [3 4]], shape=(3, 2), dtype=int32)

你在张量中获得第一行一次,第二行两次。

或者您可以使用 tf.concattf.repeat:

import tensorflow as tf

x = tf.constant([[1,2],[3,4]])
x = tf.concat([x[:1], tf.repeat(x[1:], 2, axis=0)], axis=0)
print(x)

Tensorflow 1.14.0 解决方案:

import tensorflow as tf

x = tf.constant([[1,2],[3,4]])
x = tf.concat([x[:1], tf.tile(x[1:], multiples=[2, 1])], axis=0)
print(x)