使用 numba 加速 dtaidistance 的关键函数

speedup dtaidistance key function with numba

DTAIDistance 包可用于查找单个输入查询的 k 个最佳匹配,但不支持多维输入查询。此外,我希望在一次运行中找到许多输入查询的 k 个最佳匹配。

我修改了 DTAIDistance 的函数,使其可以用于多查询的多维子序列搜索。我使用 njit 和 parallel 来加速计算:p_calc 函数利用 numba 并行地处理每个输入查询。但我发现,并行版本似乎并不比简单地逐个循环处理输入查询(即 calc 函数)更快。

import time
from tqdm import tqdm
from numba import njit, prange
import numpy as np
# Short module-level aliases for NumPy's infinity constant and argmin.
# (Numba treats module globals as compile-time constants when referenced
# from jitted code.)
inf, argmin = np.inf, np.argmin
# fastmath flags: everything in fastmath=True EXCEPT 'nnan'/'ninf'. The DTW
# matrix is deliberately full of inf, and 'ninf' lets LLVM assume no operand
# or result is infinite, which makes min()/comparisons on inf undefined.
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, nogil=True,
      error_model="numpy", cache=True, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Multi-query, multi-dimensional DTW with psi relaxation, parallel over queries.

    Query ``nn`` is ``s1[:, nn, :]``. Every query is matched against ``s2``.
    Iteration ``nn`` of the ``prange`` loop only writes ``dtw[nn]`` and
    ``d[nn]``, so iterations touch disjoint memory and need no locking.

    Parameters
    ----------
    d : 1-D array of length ``s1.shape[1]``
        Output buffer. ``d[nn]`` receives the final distance of query ``nn``.
    dtw : 3-D array of shape ``(s1.shape[1], r + 1, c + 1)``
        Cost buffer, filled in place. It is assumed to be pre-filled with
        ``inf`` by the caller: only the zero border and the visited cells are
        written here.
    s1 : 3-D array of shape ``(r, n_queries, ndim)``
    s2 : 2-D array of shape ``(c, ndim)``
    r, c : int
        Query length and ``s2`` length.
    psi_1b, psi_2b : int
        Free start. Rows ``0..psi_1b`` of column 0 and columns
        ``0..psi_2b`` of row 0 are set to zero cost.
    psi_1e, psi_2e : int
        Free end. The result is the minimum over the last ``psi_1e`` cells
        of the final column, or the last ``psi_2e`` cells of the final row.
    window : int
        Warping-window width.
    max_step : float
        Cells whose local cost exceeds this are skipped.
    max_dist : float
        Early-abandon threshold on accumulated cost. A final distance whose
        square exceeds it is replaced by ``inf``.
    penalty : float
        Cost added to non-diagonal steps.
    psi_neg : bool
        If true, mark the selected psi-relaxed end cells with ``-1``.

    Returns
    -------
    (d, dtw)
        The same arrays that were passed in, filled in place. On return,
        ``dtw`` holds square-rooted costs.
    """
    n_series = s1.shape[1]
    ndim = s1.shape[2]
    # Free start: zero the allowed part of the first row/column of every query.
    for i in range(psi_2b + 1):
        dtw[:, 0, i] = 0
    for i in range(psi_1b + 1):
        dtw[:, i, 0] = 0
    for nn in prange(n_series):
        # All state below is per query, so prange iterations are independent.
        i1 = 0
        sc = 0
        ec = 0
        for i in range(r):
            i0 = i
            i1 = i + 1
            j_start = max(0, i - max(0, r - c) - window + 1)
            j_end = min(c, i + max(0, c - r) + window)
            if sc > j_start:
                j_start = sc
            smaller_found = False
            ec_next = i
            for j in range(j_start, j_end):
                # Squared Euclidean distance between the two ndim-vectors.
                # It is kept in a local scalar: no temporary array, and no
                # per-cell store into the shared d[] buffer. Neighbouring
                # d[nn] entries share a cache line, so such stores would
                # serialise the threads (false sharing).
                cost = 0.0
                for nd in range(ndim):
                    diff = s1[i, nn, nd] - s2[j, nd]
                    cost += diff * diff
                if max_step is not None and cost > max_step:
                    continue
                dtw[nn, i1, j + 1] = cost + min(dtw[nn, i0, j],
                                                dtw[nn, i0, j + 1] + penalty,
                                                dtw[nn, i1, j] + penalty)
                # Early abandoning: narrow the column range of later rows.
                if dtw[nn, i1, j + 1] > max_dist:
                    if not smaller_found:
                        sc = j + 1
                    if j >= ec:
                        break
                else:
                    smaller_found = True
                    ec_next = j + 1
            ec = ec_next
        # Decide which distance to return for this query.
        dtw[nn] = np.sqrt(dtw[nn])
        ic = min(c, c + window - 1)
        if psi_1e == 0 and psi_2e == 0:
            dist = dtw[nn, i1, ic]
        else:
            ir = i1
            if psi_1e != 0:
                vr = dtw[nn, ir:max(0, ir - psi_1e - 1):-1, ic]
                mir = np.argmin(vr)
                vr_mir = vr[mir]
            else:
                mir = ir
                vr_mir = np.inf
            if psi_2e != 0:
                vc = dtw[nn, ir, ic:max(0, ic - psi_2e - 1):-1]
                mic = np.argmin(vc)
                vc_mic = vc[mic]
            else:
                mic = ic
                vc_mic = np.inf
            if vr_mir < vc_mic:
                if psi_neg:
                    dtw[nn, ir:ir - mir:-1, ic] = -1
                dist = vr_mir
            else:
                if psi_neg:
                    dtw[nn, ir, ic:ic - mic:-1] = -1
                dist = vc_mic
        if max_dist and dist * dist > max_dist:
            dist = np.inf
        d[nn] = dist
    return d, dtw


# fastmath without 'nnan'/'ninf': the cost matrix is initialised with inf,
# and 'ninf' would let LLVM treat every operation on inf as undefined.
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, nogil=True)
def calc(s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Multi-dimensional DTW with psi relaxation for a single query.

    Parameters
    ----------
    s1 : 2-D array, presumably ``(r, ndim)``
        The query (the caller passes ``series1[:, ii, :]``).
    s2 : 2-D array, presumably ``(c, ndim)``
        The series to search in. Its last axis must match ``s1``'s.
    r, c : int
        Lengths of ``s1`` and ``s2``.
    psi_1b, psi_2b : int
        Free start. Rows ``0..psi_1b`` of column 0 and columns
        ``0..psi_2b`` of row 0 are set to zero cost.
    psi_1e, psi_2e : int
        Free end. The result is the minimum over the last ``psi_1e`` cells
        of the final column, or the last ``psi_2e`` cells of the final row.
    window : int
        Warping-window width.
    max_step : float
        Cells whose local cost exceeds this are skipped.
    max_dist : float
        Early-abandon threshold on accumulated cost. A final distance whose
        square exceeds it is replaced by ``inf``.
    penalty : float
        Cost added to non-diagonal steps.
    psi_neg : bool
        If true, mark the selected psi-relaxed end cells with ``-1``.

    Returns
    -------
    (dist, dtw)
        The selected distance, and the freshly allocated, square-rooted
        ``(r + 1, c + 1)`` cost matrix.
    """
    ndim = s1.shape[-1]
    dtw = np.full((r + 1, c + 1), np.inf)
    # Free start: zero the allowed part of the first row and column.
    for i in range(psi_2b + 1):
        dtw[0, i] = 0
    for i in range(psi_1b + 1):
        dtw[i, 0] = 0
    i1 = 0
    sc = 0
    ec = 0
    for i in range(r):
        i0 = i
        i1 = i + 1
        j_start = max(0, i - max(0, r - c) - window + 1)
        j_end = min(c, i + max(0, c - r) + window)
        if sc > j_start:
            j_start = sc
        smaller_found = False
        ec_next = i
        for j in range(j_start, j_end):
            # Squared Euclidean distance, accumulated in a scalar. The former
            # np.sum((s1[i] - s2[j]) ** 2) allocated a temporary array for
            # every single cell.
            cost = 0.0
            for kk in range(ndim):
                diff = s1[i, kk] - s2[j, kk]
                cost += diff * diff
            if max_step is not None and cost > max_step:
                continue
            dtw[i1, j + 1] = cost + min(dtw[i0, j],
                                        dtw[i0, j + 1] + penalty,
                                        dtw[i1, j] + penalty)
            # Early abandoning: narrow the column range of later rows.
            if dtw[i1, j + 1] > max_dist:
                if not smaller_found:
                    sc = j + 1
                if j >= ec:
                    break
            else:
                smaller_found = True
                ec_next = j + 1
        ec = ec_next
    # Decide which distance to return.
    dtw = np.sqrt(dtw)
    ic = min(c, c + window - 1)
    if psi_1e == 0 and psi_2e == 0:
        dist = dtw[i1, ic]
    else:
        ir = i1
        if psi_1e != 0:
            vr = dtw[ir:max(0, ir - psi_1e - 1):-1, ic]
            mir = np.argmin(vr)
            vr_mir = vr[mir]
        else:
            mir = ir
            vr_mir = np.inf
        if psi_2e != 0:
            vc = dtw[ir, ic:max(0, ic - psi_2e - 1):-1]
            mic = np.argmin(vc)
            vc_mic = vc[mic]
        else:
            mic = ic
            vc_mic = np.inf
        if vr_mir < vc_mic:
            if psi_neg:
                dtw[ir:ir - mir:-1, ic] = -1
            dist = vr_mir
        else:
            if psi_neg:
                dtw[ir, ic:ic - mic:-1] = -1
            dist = vc_mic
    if max_dist and dist * dist > max_dist:
        dist = np.inf
    return dist, dtw


mydtype = np.float32
series1 = np.random.random((16, 30, 2)).astype(mydtype)
series2 = np.random.random((100000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)
d = np.full((n_series), np.inf, dtype=mydtype)

# Warm-up: numba compiles one specialisation per argument-type signature on
# the first call. Compile both kernels here on tiny inputs with the same
# dtypes and array layouts as the timed calls, so the timings below measure
# execution rather than JIT time.
_warm_s1 = np.ascontiguousarray(series1[:2, :2])  # r=2, n_series=2
_warm_s2 = np.ascontiguousarray(series2[:3])      # c=3
p_calc(np.full(2, np.inf, dtype=mydtype), np.full((2, 3, 4), np.inf, dtype=mydtype),
       _warm_s1, _warm_s2, 2, 3, 0, 0, 3, 3, 3, np.inf, np.inf, 0.01, False)
# series1[:2, 0, :] is a strided view, matching the layout of series1[:, ii, :].
calc(series1[:2, 0, :], _warm_s2, 2, 3, 0, 0, 3, 3, 3, np.inf, np.inf, 0.01, False)

time1 = time.perf_counter()
d, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
                 series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)

time1 = time.perf_counter()
for ii in tqdm(range(series1.shape[1])):
    d, dtw1 = calc(series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                   series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)  # this one is faster

如何加速 calc 函数或 p_calc 函数,以便计算多维多查询的动态时间规整路径?

感谢您的回答。随后我简化了代码:删除了 np.sum 部分,改用显式循环,又获得了一些加速。还有进一步加速的建议吗?

import time
from numba import njit, prange
import numpy as np
# Short module-level aliases for NumPy's infinity constant and argmin.
inf, argmin = np.inf, np.argmin
# fastmath without 'nnan'/'ninf': d[] may hold inf (the caller fills it with
# inf), and 'ninf' would let LLVM treat such values as undefined.
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, nogil=True,
      error_model="numpy", cache=False, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Simplified parallel kernel: squared distances of every (query step, s2 step) pair.

    For each query ``nn`` (``s1[:, nn, :]``), compute the squared Euclidean
    distance between ``s1[i, nn]`` and ``s2[j]`` for every ``i < r`` and
    ``j < c``. ``d[nn]`` ends up holding the value of the last pair; it is
    left unchanged when ``r`` or ``c`` is 0. ``dtw`` and the psi/window/
    threshold parameters are accepted for signature compatibility but are
    not used by this simplified version.

    Returns
    -------
    (d, dtw)
        The same arrays that were passed in.
    """
    n_series = s1.shape[1]
    ndim = s1.shape[2]
    for nn in prange(n_series):
        # Accumulate in a thread-local scalar and store d[nn] once per query.
        # Storing into d[nn] inside the hot loop makes threads working on
        # neighbouring queries contend for the same cache line (false
        # sharing), which defeats the parallel speed-up.
        last = d[nn]
        for i in range(r):
            for j in range(c):
                val = 0.0
                for nd in range(ndim):
                    diff = s1[i, nn, nd] - s2[j, nd]
                    val += diff * diff
                last = val
        d[nn] = last
    return d, dtw


# fastmath without 'nnan'/'ninf', because this function can return inf.
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, nogil=True)
def calc(dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Simplified single-query kernel: squared distance of every (i, j) pair.

    ``s1`` is presumably ``(r, ndim)`` (the caller passes ``series1[:, ii, :]``)
    and ``s2`` is ``(c, ndim)``. ``dtw`` and the psi/window/threshold
    parameters are accepted for signature compatibility but unused.

    Returns
    -------
    (d, dtw)
        ``d`` is the squared distance of the last pair ``(r - 1, c - 1)``, or
        ``inf`` when ``r`` or ``c`` is 0 (previously ``d`` was left
        unassigned in that case). ``dtw`` is returned unchanged.
    """
    ndim = s1.shape[-1]
    d = np.inf
    for i in range(r):
        for j in range(c):
            d = 0.0
            for kk in range(ndim):
                diff = s1[i, kk] - s2[j, kk]
                d += diff * diff
    return d, dtw


mydtype = np.float32
series1 = np.random.random((16, 300, 2)).astype(mydtype)
series2 = np.random.random((1000000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
# NOTE(review): this buffer is (300, 17, 1_000_001) float32, i.e. ~20 GB, and
# neither p_calc nor calc in this snippet reads it. Consider shrinking it.
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)
d = np.full((n_series), np.inf, dtype=mydtype)

# Warm-up: compile both kernels on tiny inputs with the same dtypes and array
# layouts as the timed calls, so the timings exclude JIT compilation.
# p_calc uses cache=False, so it recompiles in every new process.
_warm_s1 = np.ascontiguousarray(series1[:2, :2])  # r=2, n_series=2
_warm_s2 = np.ascontiguousarray(series2[:3])      # c=3
p_calc(np.full(2, np.inf, dtype=mydtype), np.full((2, 3, 4), np.inf, dtype=mydtype),
       _warm_s1, _warm_s2, 2, 3, 0, 0, 3, 3, 3, np.inf, np.inf, 0.01, False)
# series1[:2, 0, :] is a strided view, matching the layout of series1[:, ii, :].
calc(np.full((3, 4), np.inf, dtype=mydtype), series1[:2, 0, :], _warm_s2,
     2, 3, 0, 0, 3, 3, 3, np.inf, np.inf, 0.01, False)

time1 = time.perf_counter()
d1, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
                  series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)

# Allocate calc's buffer before starting the timer. The p_calc timing above
# does not include its buffer allocation either.
dtw = np.full((r + 1, c + 1), np.inf, dtype=mydtype)
time1 = time.perf_counter()
for ii in range(series1.shape[1]):
    d2, dtw2 = calc(dtw, series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                    series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)  # this one is faster
# In a script a bare np.allclose(...) expression is discarded, so print the checks.
print(np.allclose(dtw1[-1], dtw2))
print(np.allclose(d1[-1], d2))

编辑:我发现在下面的代码中,使用 `pass` 还是 `break` 会导致性能差异很大。我不明白为什么?

@njit(fastmath=True, nogil=True)
def kbest_matches(matching,k=4000):
    """Minimal repro: take the argmin of ``matching`` up to ``k`` times.

    The loop stops early when the argmin is index 0. The function always
    returns 0; nothing computed in the loop is returned.

    NOTE(review): ``matching`` is never modified, so every iteration computes
    the same ``best_idx``. A real k-best search would have to exclude each
    found index (e.g. overwrite it with inf) before the next argmin.

    NOTE(review): with ``pass`` instead of ``break``, ``best_idx`` has no
    observable effect, so the optimizer is presumably free to drop the
    argmin calls. With ``break``, the result drives control flow and must be
    computed. This would explain the timing gap, but it has not been verified
    here (e.g. by inspecting the generated LLVM IR).
    """
    ki = 0
    while  ki < k:
        # Full scan of `matching` on every iteration (the result is loop-invariant).
        best_idx =np.argmin(matching)# np.argmin(np.arange(10000000))#
        if best_idx == 0 :
            # Replacing this `break` with `pass` is the variant discussed in the text.
            break
        ki += 1
    return 0

ss = np.random.random((1575822,))
# Warm-up on a tiny array of the same dtype and layout (1-D float64) so that
# the timing below measures execution only, not JIT compilation.
kbest_matches(np.random.random((8,)))
time1 = time.perf_counter()
pp = kbest_matches(ss)
print(time.perf_counter() - time1)

我假设这两种实现的代码都是正确的并且已经过仔细检查(否则基准测试将毫无意义)。

问题可能出在函数的编译时间上。事实上,即使使用 cache=True,第一次调用也比后续调用慢得多。这对于并行实现尤其重要,因为并行 Numba 代码更复杂,编译通常更慢。避免这种情况的最佳方法是向 Numba 提供类型签名,让它提前(eager)编译函数。

除此之外,只对一次计算进行基准测试通常被认为是不好的做法。好的基准测试会执行多次迭代,并丢弃第一次的结果(或单独考虑它)。事实上,代码第一次执行时还可能出现其他几个问题:CPU 缓存(和 TLB)是冷的;CPU 频率可能在执行过程中变化,且在程序刚启动时往往较低;还可能发生缺页异常(page fault)。

实际上,我无法重现该问题:在我的 6 核机器上,p_calc 快了 3.3 倍。当基准测试在 5 次迭代的循环中进行时,并行实现的测量时间要短得多,大约快 13 倍。对于在 6 核机器上使用 6 个线程的并行实现来说,这一结果其实值得怀疑。