Theano 高效地按行切换

Theano switch row-wise efficiently

我有以下代码

output = T.switch(cond, a, b)

其中 cond 是一个 (N,1) 布尔张量,而 ab(N, M) 数值张量,其中 M 非常大.该条件以逐行方式运行。

我可以通过 运行 T.repeat()cond 上轻松地使开关工作,但这很慢。有没有一种方法可以有效地让 cond 中的布尔值决定是否应返回 ab

Is there a way I can efficiently make the bools in cond decide whether a or b should be returned?

是的,你可以

cond * a + (1-cond) * b

cond 将广播到 (N, M) 形状。

这应该接近理论极限,也就是内存带宽:这个操作需要读取约N*M个元素,写入N*M个元素。

相反,我们阅读 2*N*M,但删除了条件逻辑。

(我面前没有 Theano,所以我不确定它是否比 T.switch 快,但它应该尽可能好。另外,我会尝试转换condab)

相同的 dtype

如果您想就地更新 a,可以使用 T.set_subtensor:

a = np.random.uniform(size=(N, M)).astype(np.float32)
b = np.random.uniform(size=(N, M)).astype(np.float32)

a = theano.shared(a)
b = theano.shared(b)

c = T.vector() # mostly 0, presumably (1-cond)

nz = T.nonzero(c)

s = T.set_subtensor(a[nz], b[nz])
fn = theano.function([c], [], updates=[(a, s)])

...

fn(1-cond)

它可能比第一种方法快,也可能不快,具体取决于 NM 和其他因素。