TensorFlow 2.0 分散添加

TensorFlow 2.0 scatter add

我想在 TensorFlow 2.0 中实现以下设计。

给定一个 memory 张量,形状为 [a, b, c]
形状为 [a, 1]
indices 张量 和形状为 [a, c]

updates 张量

我想在 indices 指示的位置用 updates.

中的值递增 memory

tf.tensor_scatter_nd_add 好像不行:

tf.tensor_scatter_nd_add(memory, indices, updates) returns {InvalidArgumentError}Inner dimensions of output shape must match inner dimensions of updates shape. Output: [a,b,c] updates: [a,c] [Op:TensorScatterAdd].

updates 真的有必要拥有和 memory 一样多的内部维度吗?在我的逻辑中,memory[indices](作为伪代码)应该已经是一个形状为 [a, c] 的张量。此外, tf.gather_nd(params=memory, indices=indices, batch_dims=1) 的形状已经是 [a, c].

你能推荐一个替代品吗?

谢谢。

我想你想要的是这个:

import tensorflow as tf

a, b, c = 3, 4, 5
memory = tf.ones([a, b, c])
indices = tf.constant([[2], [0], [3]])
updates = 10 * tf.reshape(tf.range(a * c, dtype=memory.dtype), [a, c])
print(updates.numpy())
# [[  0.  10.  20.  30.  40.]
#  [ 50.  60.  70.  80.  90.]
#  [100. 110. 120. 130. 140.]]

# Make indices for first dimension
ind_a = tf.range(tf.shape(indices, out_type=indices.dtype)[0])
# Make full indices
indices_2 = tf.concat([tf.expand_dims(ind_a, 1), indices], axis=1)
# Scatter add
out = tf.tensor_scatter_nd_add(memory, indices_2, updates)
print(out.numpy())
# [[[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.  11.  21.  31.  41.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[ 51.  61.  71.  81.  91.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]]
# 
#  [[  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [  1.   1.   1.   1.   1.]
#   [101. 111. 121. 131. 141.]]]