tf.tensor_scatter_nd_add 功能有效

tf.tensor_scatter_nd_add function did work

我想在具有两个新值矩阵的张量的第一个维度中插入两个切片,我正在使用方法 tensor_scatter_add 但它给我一个错误

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 5, 4])
updated = tf.tensor_scatter_add(tensor, indices, updates)
with tf.Session() as se:
  print(ses.run(scatter))

tensor 的内部 2 个维度必须与 updates 的内部 2 个维度匹配。 Dimension 0 in both shapes must be equal, but are 5 and 4.

tensor 必须与 updates 相同 dtype 但在您的代码中不同。

有错误:

with tf.Session() as se:
  print(ses.run(scatter))

你将 tf.Session() 别名为 se 但调用 ses 而不是 se 并且你的传递分散到 ses.run() 但它没有在任何地方定义; se.run(updated) 应该是正确的函数调用。

带有代码修复的代码段:
这应该适合你。

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 4, 4], dtype=tf.int32)
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
with tf.Session() as se:
  print(se.run(updated))

只需更正这些行,您输入错误,这会导致您的代码出现问题:

tensor = tf.ones([4, 4, 4])
updated = tf.tensor_scatter_add(tensor, indices, updates)
with tf.Session() as se:
  print(se.run(scatter))