如何在 Theano 的 TensorVariable 上执行范围?

How to perform a range on a Theano's TensorVariable?

如何对 Theano 的 TensorVariable 进行取值?

示例:

import theano.tensor as T
from theano import function

constant = T.dscalar('constant')
n_iters = T.dscalar('n_iters')
start = T.dscalar('start')
result = start  

for iter in range(n_iters):
    result = start + constant

f = function([start, constant, n_iters], result)
print('f(0,2,5): {0}'.format(f(1,2)))

returns错误:

Traceback (most recent call last):
  File "test_theano.py", line 9, in <module>
    for iter in range(n_iters):
TypeError: range() integer end argument expected, got TensorVariable.

在 Theano 的 TensorVariable 上使用范围的正确方法是什么?

不清楚这段代码试图做什么,因为即使循环有效,它也不会计算任何有用的东西:result 将始终等于 start 和 [=16 的总和=] 与迭代次数无关。

我会假设意图是像这样计算 result

for iter in range(n_iters):
    result = result + constant

问题是你将符号、延迟执行、Theano 操作与非符号、立即执行、Python 操作混合在一起。

range 是一个 Python 函数,它需要一个 Python 整数参数,但您提供的是 Theano 符号值(n_iters,它具有双精度类型一个整数类型,但我们假设它实际上是一个 iscalar 而不是 dscalar)。就 Python 而言,所有 Theano 符号张量都只是对象:Theano 库中某处 class 类型的实例;它们肯定不是整数。即使你眯着眼睛试图假装 Theano iscalar 看起来像一个 Python 整数,它仍然不起作用,因为 Python 操作立即执行,这意味着 n_iters 需要一个值立即。另一方面,Theano 没有任何 iscalar 的值,直到通过调用已编译的 Theano 函数(或通过 eval)提供一个值。

要创建符号范围,您可以使用 theano.tensor.arange,其操作方式与 NumPy 的 arange 类似,但只是符号化。

示例:

import theano.tensor as T
from theano import function

my_range_max = T.iscalar('my_range_max')
my_range = T.arange(my_range_max)

f = function([my_range_max], my_range)
print('f(10): {0}'.format(f(10)))

输出:

f(10): [0 1 2 3 4 5 6 7 8 9]

通过使 n_iters 成为符号变量,您隐含地表示 "I don't know how many iterations there need to be in this loop until a value is provided for n_iters later"。既然如此,您必须使用符号循环而不是 Python for 循环。在后一种情况下,您必须告诉 Python 要迭代多少次 现在 ,您不能延迟该决定,直到稍后为 n_iters 提供值。要解决这个问题,您需要切换到符号循环,这是由 Theano 的 scan 运算符提供的。

这里的代码更改为使用 scan(以及其他假定的更改)。

import theano
import theano.tensor as T
from theano import function

constant = T.dscalar('constant')
n_iters = T.iscalar('n_iters')
start = T.dscalar('start')

results, _ = theano.scan(lambda result, constant: result + constant,
                         outputs_info=[start], non_sequences=[constant], n_steps=n_iters)

f = function([start, constant, n_iters], results[-1])
print('f(0,2,5): {0}'.format(f(0, 2, 5)))