Theano 扫描:如何将元组输出作为下一步的输入?
Theano scan: how do I have a tuple output as input into next step?
我想在 Theano 中做这种循环:
def add_multiply(a,b, k):
return a+b+k, a*b*k, k
x=1
y=2
k=1
tuples = []
for i in range(5):
x,y,k = add_multiply(x,y,k)
tuples.append((x,y,k))
然而,当我这样做时
x0 = T.dvector('x0')
i = T.iscalar('i')
results,updates=th.scan(fn=add_multiply,outputs_info=[{'initial':x0,'taps':[-1]}],n_steps=i)
我得到 TypeError: add_multiply() takes exactly 3 arguments (1 given)
。如果我更改它以便函数采用单个元组代替,I get ValueError: length not known
特别是,我最终想根据 k 对整个结果进行微分。
第一个错误是因为您的 add_multiply
函数有 3 个参数,但是由于 outputs_info
列表中只有一个元素,您只提供了一个参数。目前尚不清楚您是希望 x0
向量仅作为 a
的初始值,还是希望它分布在 a
、b
和 [=17= 上]. Theano 不支持后者,而且通常 Theano 不支持元组。在 Theano 中,一切都必须是张量(例如,标量只是零维张量的特殊类型)。
您可以在 Theano 中实现 Python 实现的副本,如下所示。
import theano
import theano.tensor as tt
def add_multiply(a, b, k):
return a + b + k, a * b * k
def python_main():
x = 1
y = 2
k = 1
tuples = []
for i in range(5):
x, y = add_multiply(x, y, k)
tuples.append((x, y, k))
return tuples
def theano_main():
x = tt.constant(1, dtype='uint32')
y = tt.constant(2, dtype='uint32')
k = tt.scalar(dtype='uint32')
outputs, _ = theano.scan(add_multiply, outputs_info=[x, y], non_sequences=[k], n_steps=5)
g = theano.grad(tt.sum(outputs), k)
f = theano.function(inputs=[k], outputs=outputs + [g])
tuples = []
xvs, yvs, _ = f(1)
for xv, yv in zip(xvs, yvs):
tuples.append((xv, yv, 1))
return tuples
print 'Python:', python_main()
print 'Theano:', theano_main()
请注意,在 Theano 版本中,所有元组处理都发生在 Theano 之外; Python 必须将 Theano 函数返回的三个张量转换为元组列表。
更新:
不清楚 "the entire result" 应该指的是什么,但代码已更新以显示您可以如何区分 k
。请注意,在 Theano 中,符号微分仅适用于标量表达式,但可以针对多维张量进行微分。
在此更新中,add_multiply
方法不再 returns k
,因为它是不变的。出于类似的原因,Theano 版本现在接受 k
作为 non_sequence
.
我想在 Theano 中做这种循环:
def add_multiply(a,b, k):
return a+b+k, a*b*k, k
x=1
y=2
k=1
tuples = []
for i in range(5):
x,y,k = add_multiply(x,y,k)
tuples.append((x,y,k))
然而,当我这样做时
x0 = T.dvector('x0')
i = T.iscalar('i')
results,updates=th.scan(fn=add_multiply,outputs_info=[{'initial':x0,'taps':[-1]}],n_steps=i)
我得到 TypeError: add_multiply() takes exactly 3 arguments (1 given)
。如果我更改它以便函数采用单个元组代替,I get ValueError: length not known
特别是,我最终想根据 k 对整个结果进行微分。
第一个错误是因为您的 add_multiply
函数有 3 个参数,但是由于 outputs_info
列表中只有一个元素,您只提供了一个参数。目前尚不清楚您是希望 x0
向量仅作为 a
的初始值,还是希望它分布在 a
、b
和 [=17= 上]. Theano 不支持后者,而且通常 Theano 不支持元组。在 Theano 中,一切都必须是张量(例如,标量只是零维张量的特殊类型)。
您可以在 Theano 中实现 Python 实现的副本,如下所示。
import theano
import theano.tensor as tt
def add_multiply(a, b, k):
return a + b + k, a * b * k
def python_main():
x = 1
y = 2
k = 1
tuples = []
for i in range(5):
x, y = add_multiply(x, y, k)
tuples.append((x, y, k))
return tuples
def theano_main():
x = tt.constant(1, dtype='uint32')
y = tt.constant(2, dtype='uint32')
k = tt.scalar(dtype='uint32')
outputs, _ = theano.scan(add_multiply, outputs_info=[x, y], non_sequences=[k], n_steps=5)
g = theano.grad(tt.sum(outputs), k)
f = theano.function(inputs=[k], outputs=outputs + [g])
tuples = []
xvs, yvs, _ = f(1)
for xv, yv in zip(xvs, yvs):
tuples.append((xv, yv, 1))
return tuples
print 'Python:', python_main()
print 'Theano:', theano_main()
请注意,在 Theano 版本中,所有元组处理都发生在 Theano 之外; Python 必须将 Theano 函数返回的三个张量转换为元组列表。
更新:
不清楚 "the entire result" 应该指的是什么,但代码已更新以显示您可以如何区分 k
。请注意,在 Theano 中,符号微分仅适用于标量表达式,但可以针对多维张量进行微分。
在此更新中,add_multiply
方法不再 returns k
,因为它是不变的。出于类似的原因,Theano 版本现在接受 k
作为 non_sequence
.