TensorFlow 2.0 分散添加
TensorFlow 2.0 scatter add
我想在 TensorFlow 2.0 中实现以下设计。
给定一个 memory
张量,形状为 [a, b, c]
,
形状为 [a, 1]
、
的 indices
张量
和形状为 [a, c]
、
的 updates
张量
我想在 indices
指示的位置用 updates
.
中的值递增 memory
tf.tensor_scatter_nd_add
好像不行:
tf.tensor_scatter_nd_add(memory, indices, updates)
returns {InvalidArgumentError}Inner dimensions of output shape must match inner dimensions of updates shape. Output: [a,b,c] updates: [a,c] [Op:TensorScatterAdd]
.
updates
真的有必要拥有和 memory
一样多的内部维度吗?在我的逻辑中,memory[indices]
(作为伪代码)应该已经是一个形状为 [a, c]
的张量。此外, tf.gather_nd(params=memory, indices=indices, batch_dims=1)
的形状已经是 [a, c]
.
你能推荐一个替代品吗?
谢谢。
我想你想要的是这个:
import tensorflow as tf
a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])
updates = 10 * tf.reshape(tf.range(a * c, dtype=memory.dtype), [a, c])
print(updates.numpy())
# [[ 0. 10. 20. 30. 40.]
# [ 50. 60. 70. 80. 90.]
# [100. 110. 120. 130. 140.]]
# Make indices for first dimension
ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
# Make full indices
indices_2 = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
# Scatter add
out = tf.tensor_scatter_nd_add(memory, indices_2, updates)
print(out.numpy())
# [[[ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 11. 21. 31. 41.]
# [ 1. 1. 1. 1. 1.]]
#
# [[ 51. 61. 71. 81. 91.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]]
#
# [[ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [101. 111. 121. 131. 141.]]]
我想在 TensorFlow 2.0 中实现以下设计。
给定一个 memory
张量,形状为 [a, b, c]
,
形状为 [a, 1]
、
的 indices
张量
和形状为 [a, c]
、
updates
张量
我想在 indices
指示的位置用 updates
.
memory
tf.tensor_scatter_nd_add
好像不行:
tf.tensor_scatter_nd_add(memory, indices, updates)
returns {InvalidArgumentError}Inner dimensions of output shape must match inner dimensions of updates shape. Output: [a,b,c] updates: [a,c] [Op:TensorScatterAdd]
.
updates
真的有必要拥有和 memory
一样多的内部维度吗?在我的逻辑中,memory[indices]
(作为伪代码)应该已经是一个形状为 [a, c]
的张量。此外, tf.gather_nd(params=memory, indices=indices, batch_dims=1)
的形状已经是 [a, c]
.
你能推荐一个替代品吗?
谢谢。
我想你想要的是这个:
import tensorflow as tf
a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])
updates = 10 * tf.reshape(tf.range(a * c, dtype=memory.dtype), [a, c])
print(updates.numpy())
# [[ 0. 10. 20. 30. 40.]
# [ 50. 60. 70. 80. 90.]
# [100. 110. 120. 130. 140.]]
# Make indices for first dimension
ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
# Make full indices
indices_2 = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
# Scatter add
out = tf.tensor_scatter_nd_add(memory, indices_2, updates)
print(out.numpy())
# [[[ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 11. 21. 31. 41.]
# [ 1. 1. 1. 1. 1.]]
#
# [[ 51. 61. 71. 81. 91.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]]
#
# [[ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [ 1. 1. 1. 1. 1.]
# [101. 111. 121. 131. 141.]]]