昂贵轧制 window 产品的有效总和

Efficient sum of expensive rolling window products

给定 i = 0 到 N-1 的数字序列 a[i],我正在尝试计算以下总和:

a[0] * a[1] * a[2] +
a[1] * a[2] * a[3] +
a[2] * a[3] * a[4] +
...
a[N-4] * a[N-3] * a[N-2] +
a[N-3] * a[N-2] * a[N-1]

我想将乘组的大小 G(在上面的示例 3 中)设置为可变参数。然后,可以使用简单的 O(N*G) 算法天真地获得结果,该算法可以用如下伪代码编写:

sum = 0
for i from 0 to (N-G-1):
  group_contribution = 1
  for j from 0 to (G-1):
    group_contribution *= a[i+j]
  sum += group_contribution

然而,对于大 G,很明显该算法效率极低,尤其是假设序列 a[i] 的数字事先未知并且必须在运行时进行昂贵的计算。

出于这个原因,我考虑使用以下复杂度 O(N+G) 的算法,该算法通过计算滚动乘积来回收序列 a[i] 的值:

sum = 0
rolling_product = 1
for i from 0 to (G-1):
  rolling_product *= a[i]
sum += rolling_product
for i from G to (N-1):
  rolling_product /= a[i-G]
  rolling_product *= a[i]
  sum += rolling_product

然而,我担心标准浮点表示中除法的数值稳定性。

我很想知道是否有稳定、快速的方法来计算此总和。对我来说这感觉像是一项基本的数字任务,但目前我不知道如何有效地完成它。

感谢您的任何想法!

你的滚动产品是个好主意,但如你所说,它在稳定性方面存在问题。我会这样解决:

  • 分别使用类似的系统来跟踪零和负数的数量。这些是整数和,所以没有稳定性问题。
  • 不是计算所有 a[i] 的滚动乘积,而是计算 log(abs(a[i])) 的滚动 sum,不包括零。然后,当您需要产品时,它是 (num_zeros > 0 ? 0.0 : exp(log_sum)) * sign。这将解决主要的不稳定问题。
  • 当您从巧妙的滚动 log_sum 算法中生成输出时,您应该同时构建一个 fresh log_sum减去任何东西。当新总和中的元素数量达到 G 时,然后用该数字覆盖滚动 llog_sum 并将其重置为零。这将消除长期累积的任何舍入误差。

作为序言,您可以考虑 运行 两种算法的一些测试用例并比较结果(例如,作为相对误差)。

接下来,如果你有额外的内存和时间,这里有一个 O(N log2 G) 时间和内存。它类似于常数时间的方法,线性对数 space 到 range minimum query 问题。

二次幂范围的预计算乘积

B[i][j]是2[=的乘积84=]j a 的元素从位置 i 开始,所以

B[i][j] = a[i] × a[i + 1] × ... × a[i + 2j - 1]

我们对 N log2 G 中的值感兴趣 B,即0≤j≤log2G。我们可以在 O(1) 中计算这些值中的每一个,因为

B[i][j] = B[i][j - 1] × B[ i + 2j - 1][j - 1]

计算总和

为了计算总和中的一项,我们将 G 分解为二次方大小的块。例如,如果 G = 13,则第一项为

a[0] × ... × a[12] = (a[0] × ... × a[7]) × (a[8] × ... × a[11]) × a[12] = B[0][3] × B [8][2] × B[12][0]

每个 O(N) 项都可以在 O(log2 G) 时间,因此求和的总复杂度为 O(N log2 G ).

创建一个新序列 b,其中 b[0] = a[0] * a[1] * a[2] * ... a[G-1] 等

现在你有一个更简单的问题,计算 b 值的总和,你可以保持总和,每次添加一个值时减去 b[0] 并添加新值并将它们全部向下滑动一个 (使用循环缓冲区,所以没有任何动作)。 [典型滑动window移动平均类型代码]

保留最后 G a[] 值的缓存并计算要添加到末尾的新值只是 O(G) 操作,您只计算 a[i] 一次。

是的,如果你仔细计算反向部分积,你不需要划分。

def window_products(seq, g):
    lst = list(seq)
    reverse_products = lst[:]
    for i in range(len(lst) - 2, -1, -1):
        if i % g != len(lst) % g:
            reverse_products[i] *= reverse_products[i + 1]
    product = 1
    for i in range(len(lst) - g + 1):
        yield reverse_products[i] * product
        if i % g == len(lst) % g:
            product = 1
        else:
            product *= lst[i + g]


print(list(window_products(range(10), 1)))
print(list(window_products(range(10), 2)))
print(list(window_products(range(10), 3)))