使用 Tensorflow 中的索引为二维张量赋值
Assigning values to a 2D tensor using indices in Tensorflow
我有一个二维张量 A,我希望用另一个张量 B 替换它的非零项,如下所示。
A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0,2.0,3.0,4,0,5.0],dtype=tf.float32)
所以我希望最后的 A 为
A = tf.constant([[1.0,0.0,2.0],[0,3.0,0.0],[4.0,0.0,5.0]],dtype=tf.float32)
然后我得到 A 的非零元素的索引如下
where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)
indices = <tf.Tensor: shape=(5, 2), dtype=int64, numpy=
array([[0, 0],
[0, 2],
[1, 1],
[2, 0],
[2, 2]])>
有人可以帮忙吗?
IIUC,你应该可以使用 tf.tensor_scatter_nd_update
:
import tensorflow as tf
A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0],dtype=tf.float32)
where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)
A = tf.tensor_scatter_nd_update(A, indices, B)
print(A)
tf.Tensor(
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]], shape=(3, 3), dtype=float32)
你可以试试SparseTensor
c = tf.constant([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
indices = [[1, 1]] # A list of coordinates to update.
values = [1.0] # A list of values corresponding to the respective
# coordinate in indices.
shape = [3, 3] # The shape of the corresponding dense tensor, same as `c`.
delta = tf.SparseTensor(indices, values, shape)
或scatter_update
:
tf.scatter_update(c, indices, values)
我有一个二维张量 A,我希望用另一个张量 B 替换它的非零项,如下所示。
A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0,2.0,3.0,4,0,5.0],dtype=tf.float32)
所以我希望最后的 A 为
A = tf.constant([[1.0,0.0,2.0],[0,3.0,0.0],[4.0,0.0,5.0]],dtype=tf.float32)
然后我得到 A 的非零元素的索引如下
where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)
indices = <tf.Tensor: shape=(5, 2), dtype=int64, numpy=
array([[0, 0],
[0, 2],
[1, 1],
[2, 0],
[2, 2]])>
有人可以帮忙吗?
IIUC,你应该可以使用 tf.tensor_scatter_nd_update
:
import tensorflow as tf
A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0],dtype=tf.float32)
where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)
A = tf.tensor_scatter_nd_update(A, indices, B)
print(A)
tf.Tensor(
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]], shape=(3, 3), dtype=float32)
你可以试试SparseTensor
c = tf.constant([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
indices = [[1, 1]] # A list of coordinates to update.
values = [1.0] # A list of values corresponding to the respective
# coordinate in indices.
shape = [3, 3] # The shape of the corresponding dense tensor, same as `c`.
delta = tf.SparseTensor(indices, values, shape)
或scatter_update
:
tf.scatter_update(c, indices, values)