tf.tensor_scatter_nd_add 这可能吗

Is this possible with tf.tensor_scatter_nd_add

以下使用 tf.tensor_scatter_nd_add 的简单示例给我带来了问题。

B = tf.tensor_scatter_nd_add(A, indices, updates)

张量 A 为 (1,4,4)

A = [[[1. 1. 1. 1.],
      [1. 1. 1. 1.],
      [1. 1. 1. 1.],
      [1. 1. 1. 1.]]]

期望的结果是张量 B:

B = [[[1. 1. 1. 1.],
      [1. 2. 3. 1.],
      [1. 4. 5. 1.],
      [1. 1. 1. 1.]]]

即我想将这个较小的张量添加到张量 A

的 4 个内部元素中
updates = [[[1, 2],
            [3, 4]]]

张量流 2.1.0。我尝试了多种构建索引的方法。调用 tensor_scatter_nd_add returns 一个错误,指出内部尺寸不匹配。

更新张量是否需要与 A 形状相同?

涡虫,

尝试按以下方式传递索引和更新:更新形状为 (n),索引形状为 (n,3),其中 n 是更改项的数量。 索引应指向您要更改的单个单元格:

A = tf.ones((1,4,4,), dtype=tf.dtypes.float32)
updates =  tf.constant([1., 2., 3., 4])
indices = tf.constant([[0,1,1], [0,1,2], [0,2,1], [0,2,2]])
tf.tensor_scatter_nd_add(A, indices, updates)

<tf.Tensor: shape=(1, 4, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 2., 3., 1.],
        [1., 4., 5., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>