如何在 Theano 中进行逐元素条件索引比较?

How to do element-wise conditional indexing comparison in Theano?

操作由两个长度相等的数组Xidx组成,其中idx的值可以在0到(k-1)之间变化,k的值给定.

这是用于说明这一点的通用 Python 代码。

import numpy as np

X = np.arange(6) # Just for a sample of elements
k = 3
idx = numpy.array([[0, 1, 2, 2, 0, 1]]).T # Can only contain values in [0..(k-1)]
np.array([X[np.where(idx==i)[0]] for i in range(k)])

示例输出:

array([[0, 4],
       [1, 5],
       [2, 3]])

请注意,我实际上有理由将 idx 表示为矩阵而不是向量。作为其计算的一部分,它被初始化为 numpy.zeros((n,1)),其中 nX 的大小。

我试过像这样在 Theano 中实现这个

import theano
import theano.tensor as T

X = T.vector('X')
idx = T.vector('idx')
k = T.scalar()
c = theano.scan(lambda i: X[T.where(T.eq(idx,i))], sequences=T.arange(k)) 
f = function([X,idx,k],c)

但我在定义 c 的行收到此错误:

TypeError: Wrong number of inputs for Switch.make_node (got 1((<int8>,)), expected 3)

有没有在 Theano 中实现这个的简单方法?

使用nonzero()并更正idx的尺寸。

这段代码解决了问题

import theano
import theano.tensor as T

X = T.vector('X')
idx = T.vector('idx')
k = T.scalar()
c, updates = theano.scan(lambda i: X[T.eq(idx,i).nonzero()], sequences=T.arange(k)) 
f = function([X,idx,k],c)

对于同一个例子,通过使用 Theano:

import numpy as np

X = np.arange(6) 
k = 3
idx = np.array([[0, 1, 2, 2, 0, 1]]).T

f(X, idx.T[0], k).astype(int)

这给出了输出

array([[0, 4],
       [1, 5],
       [2, 3]])

如果idx定义为np.array([0, 1, 2, 2, 0, 1]),那么可以用f(X, idx, k)代替。