相当于张量流中的 np.add.at
Equivalent for np.add.at in tensorflow
如何将 np.add.at 语句转换为 tensorflow?
np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))
编辑
self.dW.shape是(V, D),self.D.shape是(N, D),self.x.size是N
对于np.add.at
,你可能想看看tf.SparseTensor,它用值列表和索引列表表示张量(更适合稀疏数据,因此得名).
所以对于你的例子:
np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))
那将是(假设 dW
、x
和 dout
是张量):
tf.sparse_add(dW, tf.SparseTensor(x, tf.reshape(dout, [-1])))
这是假设 x
的形状为 [n, nDims]
(即 x
是 n 个索引的 'list',每个索引的维度为 nDims
),并且dout
的形状为 [n]
.
下面是 np.add.at
的一个例子:
In [324]: a=np.ones((10,))
In [325]: x=np.array([1,2,3,1,4,5])
In [326]: b=np.array([1,1,1,1,1,1])
In [327]: np.add.at(a,x,b)
In [328]: a
Out[328]: array([ 1., 3., 2., 2., 2., 2., 1., 1., 1., 1.])
如果我改用 +=
In [331]: a1=np.ones((10,))
In [332]: a1[x]+=b
In [333]: a1
Out[333]: array([ 1., 2., 2., 2., 2., 2., 1., 1., 1., 1.])
注意 a1[1]
是 2,不是 3。
如果我改用迭代解决方案
In [334]: a2=np.ones((10,))
In [335]: for i,j in zip(x,b):
...: a2[i]+=j
...:
In [336]: a2
Out[336]: array([ 1., 3., 2., 2., 2., 2., 1., 1., 1., 1.])
匹配。
如果 x
没有重复项,那么 +=
就可以正常工作。但是对于重复项,需要 add.at
来匹配迭代解决方案。
如何将 np.add.at 语句转换为 tensorflow?
np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))
编辑
self.dW.shape是(V, D),self.D.shape是(N, D),self.x.size是N
对于np.add.at
,你可能想看看tf.SparseTensor,它用值列表和索引列表表示张量(更适合稀疏数据,因此得名).
所以对于你的例子:
np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))
那将是(假设 dW
、x
和 dout
是张量):
tf.sparse_add(dW, tf.SparseTensor(x, tf.reshape(dout, [-1])))
这是假设 x
的形状为 [n, nDims]
(即 x
是 n 个索引的 'list',每个索引的维度为 nDims
),并且dout
的形状为 [n]
.
下面是 np.add.at
的一个例子:
In [324]: a=np.ones((10,))
In [325]: x=np.array([1,2,3,1,4,5])
In [326]: b=np.array([1,1,1,1,1,1])
In [327]: np.add.at(a,x,b)
In [328]: a
Out[328]: array([ 1., 3., 2., 2., 2., 2., 1., 1., 1., 1.])
如果我改用 +=
In [331]: a1=np.ones((10,))
In [332]: a1[x]+=b
In [333]: a1
Out[333]: array([ 1., 2., 2., 2., 2., 2., 1., 1., 1., 1.])
注意 a1[1]
是 2,不是 3。
如果我改用迭代解决方案
In [334]: a2=np.ones((10,))
In [335]: for i,j in zip(x,b):
...: a2[i]+=j
...:
In [336]: a2
Out[336]: array([ 1., 3., 2., 2., 2., 2., 1., 1., 1., 1.])
匹配。
如果 x
没有重复项,那么 +=
就可以正常工作。但是对于重复项,需要 add.at
来匹配迭代解决方案。