argmax 结果作为子张量
argmax result as a subtensor
我想使用保持尺寸的 argmax 作为子张量。我有:
m, argm = T.max_and_argmax(a, axis=axis, keepdims=True)
我想在 a
中将这些值设置为零。 IE。我需要使用 T.set_subtensor
。要使用它,我需要在 argm
处指定 a
的子张量 a_sub
,但我不确定它看起来像什么。 a_sub = a[argm]
多维度错误。
这应该成立:
a_sub == T.max(a, axis=axis)
a_sub.shape == T.max(a, axis=axis).shape
最后,我想做的是:
a = T.set_subtensor(a_sub, 0)
我目前的解决方案:
idx = T.arange(a.shape[axis]).dimshuffle(['x'] * axis + [0] + ['x'] * (a.ndim - axis - 1))
a = T.switch(T.eq(idx, argm), 0, a)
但是,a_sub = a[T.eq(idx, argm)]
不起作用。
你需要使用Theano的advanced indexing features which, unfortunately, differ from numpy's advanced indexing。
这是一个可以满足您要求的示例。
更新: 现在可以使用参数化轴,但请注意 axis
不能是符号化的。
import numpy
import theano
import theano.tensor as tt
theano.config.compute_test_value = 'raise'
axis = 2
x = tt.tensor3()
x.tag.test_value = numpy.array([[[3, 2, 6], [5, 1, 4]], [[2, 1, 6], [6, 1, 5]]],
dtype=theano.config.floatX)
# Identify the largest value in each row
x_argmax = tt.argmax(x, axis=axis, keepdims=True)
# Construct a row of indexes to the length of axis
indexes = tt.arange(x.shape[axis]).dimshuffle(
*(['x' for dim1 in xrange(axis)] + [0] + ['x' for dim2 in xrange(x.ndim - axis - 1)]))
# Create a binary mask indicating where the maximum values appear
mask = tt.eq(indexes, x_argmax)
# Alter the original matrix only at the places where the maximum values appeared
x_prime = tt.set_subtensor(x[mask.nonzero()], 0)
print x_prime.tag.test_value
我想使用保持尺寸的 argmax 作为子张量。我有:
m, argm = T.max_and_argmax(a, axis=axis, keepdims=True)
我想在 a
中将这些值设置为零。 IE。我需要使用 T.set_subtensor
。要使用它,我需要在 argm
处指定 a
的子张量 a_sub
,但我不确定它看起来像什么。 a_sub = a[argm]
多维度错误。
这应该成立:
a_sub == T.max(a, axis=axis)
a_sub.shape == T.max(a, axis=axis).shape
最后,我想做的是:
a = T.set_subtensor(a_sub, 0)
我目前的解决方案:
idx = T.arange(a.shape[axis]).dimshuffle(['x'] * axis + [0] + ['x'] * (a.ndim - axis - 1))
a = T.switch(T.eq(idx, argm), 0, a)
但是,a_sub = a[T.eq(idx, argm)]
不起作用。
你需要使用Theano的advanced indexing features which, unfortunately, differ from numpy's advanced indexing。
这是一个可以满足您要求的示例。
更新: 现在可以使用参数化轴,但请注意 axis
不能是符号化的。
import numpy
import theano
import theano.tensor as tt
theano.config.compute_test_value = 'raise'
axis = 2
x = tt.tensor3()
x.tag.test_value = numpy.array([[[3, 2, 6], [5, 1, 4]], [[2, 1, 6], [6, 1, 5]]],
dtype=theano.config.floatX)
# Identify the largest value in each row
x_argmax = tt.argmax(x, axis=axis, keepdims=True)
# Construct a row of indexes to the length of axis
indexes = tt.arange(x.shape[axis]).dimshuffle(
*(['x' for dim1 in xrange(axis)] + [0] + ['x' for dim2 in xrange(x.ndim - axis - 1)]))
# Create a binary mask indicating where the maximum values appear
mask = tt.eq(indexes, x_argmax)
# Alter the original matrix only at the places where the maximum values appeared
x_prime = tt.set_subtensor(x[mask.nonzero()], 0)
print x_prime.tag.test_value