Tensorflow:如何 slice/gather 所有可能的配置?

Tensorflow: How to slice/gather all possible configurations?

我有一个形状为 (batch size, sequence length, 2, N, K) 的张量(在我的特定情况下,2 表示 (x, y) 空间位置)。 N表示N个变量,K是每个变量可以取值的个数。

我想生成一个形状为 (batch size, sequence length, 2, N, K^N) 的张量,其中 K^N 来自每个 N“变量”的所有可能配置,每个 K =] 可能的值。

如何使用切片或收集在 Tensorflow 中高效地执行此操作?

让我举个形象的例子。出于说明的目的,我将省略批量大小和序列长度的 2 个主要维度。

假设这是一个 3D 张量 x,形状为 (2, N=3, K=4):

第一个配置是(稍微滥用符号)像 x[:, (0, 1, 2), (0, 0, 0)] 那样取一个切片;在这里,N=3 变量都取了它们的第一个值。第二种配置是像 x[:, (0, 1, 2), (0, 0, 1)] 这样的切片;在这里,前两个变量取第一个值,第三个变量取第二个值。继续下去,直到 4^3=64 种可能的配置,最后一种是 x[:, (0, 1, 2), (3, 3, 3)].

如果我将所有这些叠加起来,结果将是一个形状为 (2, 3, 4^3).

的张量

IIUC,如果它仍然相关,这里有一种方法可以完全使用 Tensorflow 解决您的问题(注意我使用 3D 张量并省略了前两个主要维度):

import tensorflow as tf

tf.random.set_seed(111)
x = tf.random.uniform((2, 3, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')

first_dim = tf.range(x_shape[0])
second_dim = tf.range(x_shape[1])
second_dim = tf.repeat(second_dim, tf.shape(first_dim)[0])
combination_range = tf.range(x_shape[-1])
xx, yy, zz = tf.meshgrid(combination_range, combination_range, combination_range, indexing='ij')
combinations = tf.stack([tf.reshape(xx, [-1]), tf.reshape(yy, [-1]), tf.reshape(zz, [-1])], axis=1)
print('combinations -->', combinations, '\n')

combinations = tf.reshape(tf.tile(combinations, [1, tf.shape(first_dim)[0]]), [-1])
first_dim = tf.tile(first_dim, [tf.shape(combinations)[0] // tf.shape(first_dim)[0]])
second_dim = tf.tile(second_dim, [tf.shape(combinations)[0] // tf.shape(second_dim)[0]])

result = tf.gather_nd(x, tf.transpose(tf.stack([first_dim, second_dim, combinations])))
result = tf.reshape(result, (x_shape[0], x_shape[1], x_shape[-1]**x_shape[1]))
print('final result -->', result)
x --> tf.Tensor(
[[[ 5 14  1 14]
  [ 2  1 12  3]
  [ 2  5  7 10]]

 [[ 0  9  0 12]
  [12 11  0  1]
  [ 2  6  1 12]]], shape=(2, 3, 4), dtype=int32) 

combinations --> tf.Tensor(
[[0 0 0]
 [0 0 1]
 [0 0 2]
 [0 0 3]
 [0 1 0]
 [0 1 1]
 [0 1 2]
 [0 1 3]
 [0 2 0]
 [0 2 1]
 [0 2 2]
 [0 2 3]
 [0 3 0]
 [0 3 1]
 [0 3 2]
 [0 3 3]
 [1 0 0]
 [1 0 1]
 [1 0 2]
 [1 0 3]
 [1 1 0]
 [1 1 1]
 [1 1 2]
 [1 1 3]
 [1 2 0]
 [1 2 1]
 [1 2 2]
 [1 2 3]
 [1 3 0]
 [1 3 1]
 [1 3 2]
 [1 3 3]
 [2 0 0]
 [2 0 1]
 [2 0 2]
 [2 0 3]
 [2 1 0]
 [2 1 1]
 [2 1 2]
 [2 1 3]
 [2 2 0]
 [2 2 1]
 [2 2 2]
 [2 2 3]
 [2 3 0]
 [2 3 1]
 [2 3 2]
 [2 3 3]
 [3 0 0]
 [3 0 1]
 [3 0 2]
 [3 0 3]
 [3 1 0]
 [3 1 1]
 [3 1 2]
 [3 1 3]
 [3 2 0]
 [3 2 1]
 [3 2 2]
 [3 2 3]
 [3 3 0]
 [3 3 1]
 [3 3 2]
 [3 3 3]], shape=(64, 3), dtype=int32) 

final result --> tf.Tensor(
[[[ 5  0  2 12  2  2  5  0  1 12  2  6  5  0 12 12  2  1  5  0  3 12  2
   12  5  9  2 12  5  2  5  9  1 12  5  6  5  9 12 12  5  1  5  9  3 12
    5 12  5  0  2 12  7  2  5  0  1 12  7  6  5  0 12 12]
  [ 7  1  5  0  3 12  7 12  5 12  2 12 10  2  5 12  1 12 10  6  5 12 12
   12 10  1  5 12  3 12 10 12 14  0  2 11  2  2 14  0  1 11  2  6 14  0
   12 11  2  1 14  0  3 11  2 12 14  9  2 11  5  2 14  9]
  [ 1 11  5  6 14  9 12 11  5  1 14  9  3 11  5 12 14  0  2 11  7  2 14
    0  1 11  7  6 14  0 12 11  7  1 14  0  3 11  7 12 14 12  2 11 10  2
   14 12  1 11 10  6 14 12 12 11 10  1 14 12  3 11 10 12]]

 [[ 1  0  2  0  2  2  1  0  1  0  2  6  1  0 12  0  2  1  1  0  3  0  2
   12  1  9  2  0  5  2  1  9  1  0  5  6  1  9 12  0  5  1  1  9  3  0
    5 12  1  0  2  0  7  2  1  0  1  0  7  6  1  0 12  0]
  [ 7  1  1  0  3  0  7 12  1 12  2  0 10  2  1 12  1  0 10  6  1 12 12
    0 10  1  1 12  3  0 10 12 14  0  2  1  2  2 14  0  1  1  2  6 14  0
   12  1  2  1 14  0  3  1  2 12 14  9  2  1  5  2 14  9]
  [ 1  1  5  6 14  9 12  1  5  1 14  9  3  1  5 12 14  0  2  1  7  2 14
    0  1  1  7  6 14  0 12  1  7  1 14  0  3  1  7 12 14 12  2  1 10  2
   14 12  1  1 10  6 14 12 12  1 10  1 14 12  3  1 10 12]]], shape=(2, 3, 64), dtype=int32)

对于任意 N,试试这个:

import tensorflow as tf

tf.random.set_seed(111)
x = tf.random.uniform((2, 6, 4), maxval=15, dtype=tf.int32)
x_shape = tf.shape(x)
print('x -->', x, '\n')

first_dim = tf.range(x_shape[0])
second_dim = tf.range(x_shape[1])
second_dim = tf.repeat(second_dim, tf.shape(first_dim)[0])
combination_range = tf.range(x_shape[-1])
outputs = tf.meshgrid(*tuple(tf.unstack(tf.repeat(combination_range[tf.newaxis, ...], x_shape[1], axis=0), axis=0)), indexing='ij')
combinations = tf.stack([tf.reshape(o, [-1]) for o in outputs], axis=1)
print('combinations -->', combinations, '\n')

combinations = tf.reshape(tf.tile(combinations, [1, tf.shape(first_dim)[0]]), [-1])
first_dim = tf.tile(first_dim, [tf.shape(combinations)[0] // tf.shape(first_dim)[0]])
second_dim = tf.tile(second_dim, [tf.shape(combinations)[0] // tf.shape(second_dim)[0]])

result = tf.gather_nd(x, tf.transpose(tf.stack([first_dim, second_dim, combinations])))
result = tf.reshape(result, (x_shape[0], x_shape[1], x_shape[-1]**x_shape[1]))
print('final result -->', result)