Theano 扫描:如何将元组输出作为下一步的输入?

Theano scan: how do I have a tuple output as input into next step?

我想在 Theano 中做这种循环:

def add_multiply(a,b, k):
    return a+b+k, a*b*k, k

x=1
y=2
k=1
tuples = []
for i in range(5):
    x,y,k = add_multiply(x,y,k)
    tuples.append((x,y,k))

然而,当我这样做时

x0 = T.dvector('x0')
i = T.iscalar('i')
results,updates=th.scan(fn=add_multiply,outputs_info=[{'initial':x0,'taps':[-1]}],n_steps=i)

我得到 TypeError: add_multiply() takes exactly 3 arguments (1 given)。如果我更改它以便函数采用单个元组代替,I get ValueError: length not known

特别是,我最终想根据 k 对整个结果进行微分。

第一个错误是因为您的 add_multiply 函数有 3 个参数,但是由于 outputs_info 列表中只有一个元素,您只提供了一个参数。目前尚不清楚您是希望 x0 向量仅作为 a 的初始值,还是希望它分布在 ab 和 [=17= 上]. Theano 不支持后者,而且通常 Theano 不支持元组。在 Theano 中,一切都必须是张量(例如,标量只是零维张量的特殊类型)。

您可以在 Theano 中实现 Python 实现的副本,如下所示。

import theano
import theano.tensor as tt


def add_multiply(a, b, k):
    return a + b + k, a * b * k


def python_main():
    x = 1
    y = 2
    k = 1
    tuples = []
    for i in range(5):
        x, y = add_multiply(x, y, k)
        tuples.append((x, y, k))
    return tuples


def theano_main():
    x = tt.constant(1, dtype='uint32')
    y = tt.constant(2, dtype='uint32')
    k = tt.scalar(dtype='uint32')
    outputs, _ = theano.scan(add_multiply, outputs_info=[x, y], non_sequences=[k], n_steps=5)
    g = theano.grad(tt.sum(outputs), k)
    f = theano.function(inputs=[k], outputs=outputs + [g])
    tuples = []
    xvs, yvs, _ = f(1)
    for xv, yv in zip(xvs, yvs):
        tuples.append((xv, yv, 1))
    return tuples


print 'Python:', python_main()
print 'Theano:', theano_main()

请注意,在 Theano 版本中,所有元组处理都发生在 Theano 之外; Python 必须将 Theano 函数返回的三个张量转换为元组列表。

更新:

不清楚 "the entire result" 应该指的是什么,但代码已更新以显示您可以如何区分 k。请注意,在 Theano 中,符号微分仅适用于标量表达式,但可以针对多维张量进行微分。

在此更新中,add_multiply 方法不再 returns k,因为它是不变的。出于类似的原因,Theano 版本现在接受 k 作为 non_sequence.