Theano 中的符号矩阵幂级数
Symbolic matrix power series in Theano
假设我有一个二维张量 A
。我想象征性地计算 Apow
,A
的幂级数,这是一个定义如下的 3d 张量:
Apow = [I, A, A^2, A^3, ..., A^k]
其中 A^2
表示 A.dot(A)
(即幂级数是根据点积而不是元素方式定义的)。 k
是一个符号标量,指定系列的长度。
我如何在 Theano 中实现它?解决方案似乎是基于 scan
,但我无法让它发挥作用。
有什么想法吗?
让我将我的答案分成 numpy 实现和 theano 实现:
使用 numpy:
def kpow(A, k):
if k == 0:
return np.identity(A.shape[0])
if k == 1:
return A
else:
return np.dot(A, kpow(A, k-1))
然后您可以这样得到您的 Apow
:
k = 5
A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]])
Apow = [kpow(A,i) for i in range(k)]
当然,您可以通过实际积累列表来提高这种方式的效率。需要注意的重要一点是递归,即我们如何使用先前的结果来计算下一个。
使用theano:
首先,让我们为 k
和矩阵 M
定义两个符号变量:
k = T.iscalar('k')
M = T.dmatrix('M')
接下来,让我们定义一个循环函数:
def recurrence(M, prev_result):
return prev_result * M
终于到了扫描功能的时候了:
result, updates = theano.scan(fn=recurrence,
outputs_info=T.identity_like(M),
non_sequences=M,
n_steps=k)
现在让我们得到一些结果:
A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]], dtype='int32')
kpow_theano = theano.function(inputs=[M,k], outputs=result)
Apow = [kpow_theano(A,10)[i] for i in range(10)]
我不确定你如何使用 theano 在前面获得单位矩阵。我想你可以将它添加到列表中。
假设我有一个二维张量 A
。我想象征性地计算 Apow
,A
的幂级数,这是一个定义如下的 3d 张量:
Apow = [I, A, A^2, A^3, ..., A^k]
其中 A^2
表示 A.dot(A)
(即幂级数是根据点积而不是元素方式定义的)。 k
是一个符号标量,指定系列的长度。
我如何在 Theano 中实现它?解决方案似乎是基于 scan
,但我无法让它发挥作用。
有什么想法吗?
让我将我的答案分成 numpy 实现和 theano 实现:
使用 numpy:
def kpow(A, k):
if k == 0:
return np.identity(A.shape[0])
if k == 1:
return A
else:
return np.dot(A, kpow(A, k-1))
然后您可以这样得到您的 Apow
:
k = 5
A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]])
Apow = [kpow(A,i) for i in range(k)]
当然,您可以通过实际积累列表来提高这种方式的效率。需要注意的重要一点是递归,即我们如何使用先前的结果来计算下一个。
使用theano:
首先,让我们为 k
和矩阵 M
定义两个符号变量:
k = T.iscalar('k')
M = T.dmatrix('M')
接下来,让我们定义一个循环函数:
def recurrence(M, prev_result):
return prev_result * M
终于到了扫描功能的时候了:
result, updates = theano.scan(fn=recurrence,
outputs_info=T.identity_like(M),
non_sequences=M,
n_steps=k)
现在让我们得到一些结果:
A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]], dtype='int32')
kpow_theano = theano.function(inputs=[M,k], outputs=result)
Apow = [kpow_theano(A,10)[i] for i in range(10)]
我不确定你如何使用 theano 在前面获得单位矩阵。我想你可以将它添加到列表中。