Extract several columns in one op from a tf.Tensor

In recent TensorFlow (1.13, 2.0), is there a way to extract non-contiguous slices from a tensor in one go? How? For example, with the following tensor:

1 2 3 4
5 6 7 8 

I want to extract columns 1 and 3 (zero-based) in a single op to get:

2 4
6 8

However, I can't seem to do this in a single op with slicing. What is the correct/fastest/most elegant way?

You can get all the odd columns with a combination of reshaping and slicing (this requires an even number of columns N):

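import numpy as np
import tensorflow as tf

N = 4   # number of columns (must be even for this trick)
M = 10  # number of rows
x = tf.constant(np.random.rand(M, N))
# View the data as pairs of adjacent columns, keep the second of each pair, restore M rows
slice_odd = tf.reshape(tf.reshape(x, (-1, 2))[:, 1], (-1, N // 2))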
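For the specific case of every other column, a strided slice is also a single op: `x[:, 1::2]` lowers to one `tf.strided_slice` and, unlike the reshape trick, does not need the column count to be even. A minimal sketch, assuming TensorFlow 2.x with eager execution; the variable names are illustrative only.

```python
import numpy as np
import tensorflow as tf

M, N = 10, 4
x = tf.constant(np.random.rand(M, N))

# Reshape trick from above: pair up adjacent columns, keep the second of each pair.
odd_reshape = tf.reshape(tf.reshape(x, (-1, 2))[:, 1], (-1, N // 2))

# Strided slice: start at column 1, step 2 (one tf.strided_slice op).
odd_strided = x[:, 1::2]

# Both select columns 1 and 3 of every row.
np.testing.assert_allclose(odd_reshape.numpy(), odd_strided.numpy())

# With an odd number of columns the strided slice still works (columns 1 and 3 of 5).
y = tf.constant(np.random.rand(M, 5))
odd_of_five = y[:, 1::2]
```

This only covers regularly spaced columns; for an arbitrary index list the answers below apply.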

1. Using tf.gather(tensor, columns, axis=1) (TF1.x, TF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

print(tf.gather(tensor, columns, axis=1).numpy())  # .numpy() needs TF2 / TF1.x eager
# [[2. 4.]
#  [6. 8.]]
%timeit -n 10000 tf.gather(tensor, columns, axis=1)  # IPython magic
82.6 µs ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
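Because tf.gather takes the column indices as data, the same single op can also reorder or repeat columns, and the indices can themselves be a tensor computed at runtime. A minimal sketch, assuming TensorFlow 2.x eager execution; the index list `[3, 1, 1]` and the "even first-row value" rule are arbitrary illustrations.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)

# Order is preserved and repeats are allowed.
reordered = tf.gather(tensor, [3, 1, 1], axis=1)
# [[4. 2. 2.]
#  [8. 6. 6.]]

# Indices computed at runtime: columns whose first-row value is even.
is_even = tf.equal(tf.math.floormod(tensor[0], 2.0), 0.0)  # [False, True, False, True]
even_cols = tf.reshape(tf.where(is_even), [-1])            # int64 indices [1, 3]
selected = tf.gather(tensor, even_cols, axis=1)
print(selected.numpy())
```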

2. With indexing (TF1.x, TF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess:  # <-- TF1.x
    print(sess.run(stacked))
# [[2. 4.]
#  [6. 8.]]
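The indexing approach is not a single op: each `transposed[c]` adds its own slice op, so the graph grows with the number of requested columns. One way to check this is to trace both versions with tf.function and count op types in the resulting graph. A minimal sketch, assuming TensorFlow 2.x; the helper names `by_indexing`, `by_gather` and `op_types` are hypothetical.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)

def by_indexing(t):
    transposed = tf.transpose(t)
    sliced = [transposed[c] for c in [1, 3]]  # one StridedSlice per column
    return tf.transpose(tf.stack(sliced, axis=0))

def by_gather(t):
    return tf.gather(t, [1, 3], axis=1)       # a single GatherV2

def op_types(fn):
    # Trace fn into a graph and list the type of every op it contains.
    graph = tf.function(fn).get_concrete_function(tensor).graph
    return [op.type for op in graph.get_operations()]

print(op_types(by_indexing).count("StridedSlice"))
print(op_types(by_gather).count("GatherV2"))
```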

Wrapped in a function and timed with %timeit under tf.__version__ == '2.0.0-alpha0':

154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Decorating it with @tf.function makes it more than 2x faster:

import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
    transposed = tf.transpose(tensor)
    sliced = [transposed[c] for c in columns]
    stacked = tf.transpose(tf.stack(sliced, axis=0))
    return stacked

%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
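%timeit only exists in IPython/Jupyter. Outside a notebook, the standard-library timeit module can produce a comparable "mean ± std. dev. of 7 runs" figure. A minimal sketch, assuming TensorFlow 2.x; the helper `per_call_us` is hypothetical, and absolute numbers will vary with hardware and TF version.

```python
import statistics
import timeit

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

@tf.function
def extract_columns(t):
    transposed = tf.transpose(t)
    sliced = [transposed[c] for c in columns]
    return tf.transpose(tf.stack(sliced, axis=0))

def per_call_us(fn, number=1000, repeat=7):
    # Warm-up call: for a tf.function this triggers tracing, which we exclude from timing.
    fn()
    # Each entry is the average time per call (in microseconds) of one run.
    runs = [t / number * 1e6 for t in timeit.repeat(fn, number=number, repeat=repeat)]
    return statistics.mean(runs), statistics.stdev(runs)

mean, std = per_call_us(lambda: extract_columns(tensor))
print(f"{mean:.1f} µs ± {std:.1f} µs per loop (mean ± std. dev. of 7 runs)")
```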

3. Python list comprehension, eager execution only (TF2, TF1.x-eager):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                             if i in columns], 0))
print(res.numpy())
# [[2. 4.]
#  [6. 8.]]
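Unlike tf.gather, this filter keeps the columns in their original order, whatever order `columns` lists them in (and duplicates in `columns` have no effect). A minimal sketch, assuming TensorFlow 2.x eager execution, with `columns` deliberately given in reverse order:

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [3, 1]  # requested in reverse order

# Filtering by membership walks the columns left to right: original order.
filtered = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                                  if i in columns], 0))
# tf.gather follows the order of the index list: requested order.
gathered = tf.gather(tensor, columns, axis=1)

print(filtered.numpy())  # [[2. 4.] [6. 8.]] (printed on two lines)
print(gathered.numpy())  # [[4. 2.] [8. 6.]] (printed on two lines)
```

For long index lists, `set(columns)` makes the `i in columns` test constant-time.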

%timeit in tf.__version__ == '2.0.0-alpha0':

242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

4. Use tf.one_hot() to mark the rows/columns, then tf.boolean_mask() to extract them (TF1.x, TF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])  # one one-hot row per requested column
mask = tf.cast(tf.reduce_sum(mask, axis=0), tf.bool)  # [False, True, False, True]
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess: # TF1.x
    print(sess.run(res))
# [[2. 4.]
#  [6. 8.]]
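tf.boolean_mask also accepts an `axis` argument, so a 1-D mask can be applied to the columns directly without the two transposes. A minimal sketch, assuming TensorFlow 2.x eager execution; as with the filter approach, the result follows the tensor's column order, not the order of `columns`.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

# one_hot gives [[0,1,0,0],[0,0,0,1]]; summing over axis 0 gives [0,1,0,1].
num_cols = tensor.shape[-1]
mask = tf.reduce_sum(tf.one_hot(columns, num_cols), axis=0) > 0  # bool [F, T, F, T]

# Apply the 1-D mask along axis 1 (the columns).
res = tf.boolean_mask(tensor, mask, axis=1)
print(res.numpy())
```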

%timeit in tf.__version__ == '2.0.0-alpha0':

494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)