如何在嵌套循环中高效计算上三角的 logsumexp?
How to efficiently compute logsumexp of upper triangle in a nested loop?
我有一个嵌套的 for 循环,它遍历权重矩阵的行,并将 logsumexp 应用到这些权重行的外部加法矩阵的上三角部分。它非常慢,所以我试图弄清楚如何通过矢量化或取出循环来代替矩阵运算来加快速度。
'''
Wm: weights matrix, nxk
W: updated weights matrix, nxn
triu_inds: upper triangular indices of Wxy outer matrix
'''
for x in range(n-1):
wx = Wm[x, :]
for y in range(x+1, n):
wy = Wm[y, :]
Wxy = np.add.outer(wx, wy)
Wxy = Wxy[triu_inds]
W[x, y] = logsumexp(Wxy)
logsumexp:计算输入数组的指数总和的对数
a: [1, 2, 3]
logsumexp(a) = log( exp(1) + exp(2) + exp(3) )
输入数据Wm是一个nxk维的权重矩阵。 K 表示患者传感器位置,n 表示所有此类可能的传感器位置。 Wm 中的值基本上是患者传感器与已知传感器的接近程度。
示例:
Wm = [1 2 3]
[4 5 6]
[7 8 9]
[10 11 12]
wx = [1 2 3]
wy = [4 5 6]
Wxy = [5 6 7]
[6 7 8]
[7 8 9]
triu_indices = ([0, 0, 1], [1, 2, 2])
Wxy[triu_inds] = [6, 7, 8]
logsumexp(Wxy[triu_inds]) = log(exp(6) + exp(7) + exp(8))
您可以对整个矩阵执行外积 Wm
,然后交换对应于操作数 1 中的列和操作数 2 中的行的轴,以便将三角形索引应用于列。生成的矩阵填充了所有行的组合,所以你需要 select 上三角部分。
W = logsumexp(
np.add.outer(Wm, Wm).swapaxes(1, 2)[(slice(None),)*2 + triu_inds],
axis=-1 # Perform summation over last axis.
)
W = np.triu(W, k=1)
我有一个嵌套的 for 循环,它遍历权重矩阵的行,并将 logsumexp 应用到这些权重行的外部加法矩阵的上三角部分。它非常慢,所以我试图弄清楚如何通过矢量化或取出循环来代替矩阵运算来加快速度。
'''
Wm: weights matrix, nxk
W: updated weights matrix, nxn
triu_inds: upper triangular indices of Wxy outer matrix
'''
for x in range(n-1):
wx = Wm[x, :]
for y in range(x+1, n):
wy = Wm[y, :]
Wxy = np.add.outer(wx, wy)
Wxy = Wxy[triu_inds]
W[x, y] = logsumexp(Wxy)
logsumexp:计算输入数组的指数总和的对数
a: [1, 2, 3]
logsumexp(a) = log( exp(1) + exp(2) + exp(3) )
输入数据Wm是一个nxk维的权重矩阵。 K 表示患者传感器位置,n 表示所有此类可能的传感器位置。 Wm 中的值基本上是患者传感器与已知传感器的接近程度。
示例:
Wm = [1 2 3]
[4 5 6]
[7 8 9]
[10 11 12]
wx = [1 2 3]
wy = [4 5 6]
Wxy = [5 6 7]
[6 7 8]
[7 8 9]
triu_indices = ([0, 0, 1], [1, 2, 2])
Wxy[triu_inds] = [6, 7, 8]
logsumexp(Wxy[triu_inds]) = log(exp(6) + exp(7) + exp(8))
您可以对整个矩阵执行外积 Wm
,然后交换对应于操作数 1 中的列和操作数 2 中的行的轴,以便将三角形索引应用于列。生成的矩阵填充了所有行的组合,所以你需要 select 上三角部分。
W = logsumexp(
np.add.outer(Wm, Wm).swapaxes(1, 2)[(slice(None),)*2 + triu_inds],
axis=-1 # Perform summation over last axis.
)
W = np.triu(W, k=1)