如何使用 Theano 将列中的最大值替换为 1.0?

How to replace maximal values in columns by 1.0 using Theano?

我在 Theano 中有一个实值矩阵,我想生成另一个矩阵,这样在新矩阵的每一列中我都有一个 1.0 和 0.0 否则。 1.0 应该表示最大值在输入矩阵列中的位置。

例如。如果使用下面的矩阵作为输入

1.0   2.0   3.0   5.0
2.1   0.0   4.0   0.0
0.0   3.0   1.0   4.0

必须生成以下矩阵作为输出:

0.0   0.0   0.0   1.0
1.0   0.0   1.0   0.0
0.0   1.0   0.0   0.0

我目前使用的解决方案如下:

tmp = T.max(inp, axis = 0).dimshuffle('x',0)
out = T.switch(T.eq(tmp, inp), 1.0, 0.0)

这个解决方案似乎可行,但我不确定它有多稳健。主要关注的是,我比较当前值是否 正好 等于列中的最大值。会不会因为某些 "rounding" 错误,最大值不会被我们这样识别?

我会使用 argmax 函数而不是 max & dimshuffle 方法。 Argmax returns 最大值的索引,而不是值本身。

tmp = T.argmax(inp, axis = 0)

然后您可以用全零初始化一个矩阵,并使用您的 tmp 数组将所需的索引设置为 1.0(我现在无法 test/provide 为这部分编写代码,但它应该是微不足道的)

试试这个:

out = T.eye(3)[T.argmax(inp, axis=0)].T  # replace 3 with number of rows

这将输出:

0.0   0.0   0.0   1.0
1.0   0.0   1.0   0.0
0.0   1.0   0.0   0.0