Tensorflow中某些元素的逆序
Reverse order of some elements in Tensorflow
假设我有一个形状为 (M, N, 2)
的张量 DATA
。
我还有另一个形状为 (N) 的张量 IND
,由 0 和 1 组成。
如果IND(i)==1
那么DATA(:,i,0)
和DATA(:,i,1)
必须交换。如果 IND(i)==0
他们不会交换。
我该怎么做?我知道这可以通过 tf.gather_nd
完成,但我不知道怎么做。
一种不使用tf.gather_ind的方法如下。这个想法是构建 DATA1,它是具有所有可能交换的 DATA(即,如果 IND 是 1 的向量,交换的结果),并根据是否需要交换使用掩码从 Data 或 Data1 中选择正确的值或不。
DATA1 = tf.concat([tf.reshape(DATA[:,:,1], [M, N, 1]), tf.reshape(DATA[:,:,0], [M, N, 1])], axis = 2)
Mask1 = tf.cast(tf.reshape(IND, [1, N, 1]), tf.float64)
Mask0 = 1 - Mask1
Res = tf.multiply(Mask0, DATA) + tf.multiply(Mask1, DATA1)
这是 tf.equal
, tf.where
, tf.scater_nd_update
, tf.gather_nd
and tf.reverse_v2
的一种可能解决方案:
data = tf.Variable([[[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6]]]) # shape=(1,5,2)
# reverse elements where ind is 1
ind = tf.constant([1, 0, 1, 0, 1]) # shape(5,)
cond = tf.where(tf.equal([ind], 1))
match_data = tf.gather_nd(data, cond)
rev_match_data = tf.reverse_v2(match_data, axis=[-1])
data = tf.scatter_nd_update(data, cond, rev_match_data)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(data))
#[[[2 1]
# [2 3]
# [4 3]
# [4 5]
# [6 5]]]
假设我有一个形状为 (M, N, 2)
的张量 DATA
。
我还有另一个形状为 (N) 的张量 IND
,由 0 和 1 组成。
如果IND(i)==1
那么DATA(:,i,0)
和DATA(:,i,1)
必须交换。如果 IND(i)==0
他们不会交换。
我该怎么做?我知道这可以通过 tf.gather_nd
完成,但我不知道怎么做。
一种不使用tf.gather_ind的方法如下。这个想法是构建 DATA1,它是具有所有可能交换的 DATA(即,如果 IND 是 1 的向量,交换的结果),并根据是否需要交换使用掩码从 Data 或 Data1 中选择正确的值或不。
DATA1 = tf.concat([tf.reshape(DATA[:,:,1], [M, N, 1]), tf.reshape(DATA[:,:,0], [M, N, 1])], axis = 2)
Mask1 = tf.cast(tf.reshape(IND, [1, N, 1]), tf.float64)
Mask0 = 1 - Mask1
Res = tf.multiply(Mask0, DATA) + tf.multiply(Mask1, DATA1)
这是 tf.equal
, tf.where
, tf.scater_nd_update
, tf.gather_nd
and tf.reverse_v2
的一种可能解决方案:
data = tf.Variable([[[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6]]]) # shape=(1,5,2)
# reverse elements where ind is 1
ind = tf.constant([1, 0, 1, 0, 1]) # shape(5,)
cond = tf.where(tf.equal([ind], 1))
match_data = tf.gather_nd(data, cond)
rev_match_data = tf.reverse_v2(match_data, axis=[-1])
data = tf.scatter_nd_update(data, cond, rev_match_data)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(data))
#[[[2 1]
# [2 3]
# [4 3]
# [4 5]
# [6 5]]]