NumPy - np.searchsorted for 2-D arrays

np.searchsorted works only with 1-D arrays.

I have a 2-D array that is lexicographically sorted, meaning that row 0 is sorted, then for equal values in row 0 the corresponding elements of row 1 are also sorted, then for equal values in rows 0 and 1 the values of row 2 are also sorted, and so on. In other words, the tuples formed by the columns are sorted.
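To illustrate, a small hand-made example of such a column-lexicographically sorted array (the values are hypothetical):

import numpy as np

# Columns, read top to bottom as tuples, are in non-decreasing order:
# (0, 1, 5) <= (0, 2, 3) <= (1, 0, 2) <= (1, 0, 7)
a = np.array([[0, 0, 1, 1],
              [1, 2, 0, 0],
              [5, 3, 2, 7]])
# Such an ordering can be produced by: a = a[:, np.lexsort(a[::-1])]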

I have another 2-D array whose columns are also tuples, and these columns need to be inserted into the correct column positions of the first 2-D array. For the 1-D case np.searchsorted is usually used to find the correct positions.
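For reference, a minimal sketch of the 1-D case with made-up values:

import numpy as np

a = np.array([1, 3, 3, 7])    # already sorted
v = np.array([0, 3, 8])       # values to insert
print(np.searchsorted(a, v))  # [0 1 4] -- positions that keep a sorted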

But is there an alternative to np.searchsorted for 2-D arrays, something similar to how np.lexsort is the 2-D alternative to the 1-D np.argsort?

If there is no such function, can this functionality be implemented efficiently using existing numpy functions?

I am interested in efficient solutions for arrays of any dtype, including np.object_.

One naive approach that handles any dtype would be to convert every column of both arrays into a 1-D array (or tuple) and then store these columns in another 1-D array of dtype = np.object_. Maybe it is not that naive and could even be fast, especially when the columns are tall.

Posting the first naive solution mentioned in the question: it just converts the 2-D array into a 1-D array of dtype = np.object_ containing the original columns as Python tuples, and then uses the 1-D np.searchsorted. This solution works for any dtype. In fact it is not that naive: it is quite fast, as measured in my other answer to this question, especially for key lengths below 100.

Try it online!

import numpy as np
np.random.seed(0)

def to_obj(x):
    # Convert each row of x into a Python tuple and store the tuples in a 1-D object array
    res = np.empty((x.shape[0],), dtype = np.object_)
    res[:] = [tuple(np.squeeze(e, 0)) for e in np.split(x, x.shape[0], axis = 0)]
    return res

a = np.random.randint(0, 3, (10, 23))
b = np.random.randint(0, 3, (10, 15))

# Sort the columns of both arrays lexicographically (row 0 is the most significant key)
a, b = [x[:, np.lexsort(x[::-1])] for x in (a, b)]

print(np.concatenate((np.arange(a.shape[1])[None, :], a)), '\n\n', b, '\n')

a, b = [to_obj(x.T) for x in (a, b)]

print(np.searchsorted(a, b))

I have created several more advanced strategies.

Also implemented the simple strategy that uses tuples, like in the other answer.

Timings of all solutions are measured.

Most of the strategies use np.searchsorted as the underlying engine. To implement these advanced strategies a special wrapper class _CmpIx is used, which provides a custom comparison function (__lt__) for the np.searchsorted call.
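The trick, shown here as a minimal standalone sketch (a simplified stand-in for the real _CmpIx class below, with hypothetical data): np.searchsorted on an object array invokes whatever __lt__ the stored objects define, so each column can be represented by a tiny proxy object that compares columns lazily:

import numpy as np

class Col:
    # Simplified proxy: 'mat' holds all columns, 'i' is this column's index.
    def __init__(self, mat, i):
        self.mat, self.i = mat, i
    def __lt__(self, other):
        a, b = self.mat[:, self.i], other.mat[:, other.i]
        ix = np.flatnonzero(a != b)                    # first differing key element
        return a[ix[0]] < b[ix[0]] if ix.size else False

ab = np.array([[0, 0, 1, 0],
               [1, 2, 0, 1]])                          # columns 0..2 are sorted, column 3 is the query
cols = np.array([Col(ab, i) for i in range(4)], dtype=np.object_)
print(np.searchsorted(cols[:3], [cols[3]]))            # [0] -- insertion position of column 3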

  1. The py.tuples strategy just converts all columns into tuples, stores them as a numpy 1-D array of np.object_ dtype, and then does a regular searchsorted.
  2. py.zip uses python's zip to do the same task lazily.
  3. The np.lexsort strategy just uses np.lexsort to compare two columns lexicographically.
  4. np.nonzero uses the np.flatnonzero(a != b) expression.
  5. cmp_numba uses ahead-of-time compiled numba code inside the _CmpIx wrapper for fast lazy lexicographic comparison of the two provided elements.
  6. np.searchsorted uses the standard numpy function, but it is measured for the 1-D case only.
  7. For the numba strategy the whole search algorithm is implemented from scratch using the Numba engine; the algorithm is based on binary search. There are _py and _nm variants of this algorithm: _nm is much faster because it uses the Numba compiler, while _py is the same algorithm but un-compiled. There is also a _sorted flavour, which applies an extra optimization based on the array to be inserted being already sorted.
  8. view1d - methods suggested by @MadPhysicist in his answer. They are commented out in the code, because for most tests with key length > 1 they returned incorrect answers, probably due to some problem with the raw viewing of the array.

Try it online!

class SearchSorted2D:
    class _CmpIx:
        def __init__(self, t, p, i):
            self.p, self.i = p, i
            self.leg = self.leg_cache()[t]
            self.lt = lambda o: self.leg(self, o, False) if self.i != o.i else False
            self.le = lambda o: self.leg(self, o, True) if self.i != o.i else True
        @classmethod
        def leg_cache(cls):
            if not hasattr(cls, 'leg_cache_data'):
                cls.leg_cache_data = {
                    'py.zip': cls._leg_py_zip, 'np.lexsort': cls._leg_np_lexsort,
                    'np.nonzero': cls._leg_np_nonzero, 'cmp_numba': cls._leg_numba_create(),
                }
            return cls.leg_cache_data
        def __eq__(self, o): return not self.lt(o) and self.le(o)
        def __ne__(self, o): return self.lt(o) or not self.le(o)
        def __lt__(self, o): return self.lt(o)
        def __le__(self, o): return self.le(o)
        def __gt__(self, o): return not self.le(o)
        def __ge__(self, o): return not self.lt(o)
        @staticmethod
        def _leg_np_lexsort(self, o, eq):
            import numpy as np
            ia, ib = (self.i, o.i) if eq else (o.i, self.i)
            return (np.lexsort(self.p.ab[::-1, ia : (ib + (-1, 1)[ib >= ia], None)[ib == 0] : ib - ia])[0] == 0) == eq
        @staticmethod
        def _leg_py_zip(self, o, eq):
            for l, r in zip(self.p.ab[:, self.i], self.p.ab[:, o.i]):
                if l < r:
                    return True
                if l > r:
                    return False
            return eq
        @staticmethod
        def _leg_np_nonzero(self, o, eq):
            import numpy as np
            a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
            ix = np.flatnonzero(a != b)
            return a[ix[0]] < b[ix[0]] if ix.size != 0 else eq
        @staticmethod
        def _leg_numba_create():
            import numpy as np

            try:
                from numba.pycc import CC
                cc = CC('ss_numba_mod')
                @cc.export('ss_numba_i8', 'b1(i8[:],i8[:],b1)')
                def ss_numba(a, b, eq):
                    for i in range(a.size):
                        if a[i] < b[i]:
                            return True
                        elif b[i] < a[i]:
                            return False
                    return eq
                cc.compile()
                success = True
            except Exception:
                success = False
                
            if success:
                try:
                    import ss_numba_mod
                except Exception:
                    success = False
            
            def odo(self, o, eq):
                a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
                assert a.ndim == 1 and a.shape == b.shape, (a.shape, b.shape)
                return ss_numba_mod.ss_numba_i8(a, b, eq)
                
            return odo if success else None

    def __init__(self, type_):
        import numpy as np
        self.type_ = type_
        self.ci = np.array([], dtype = np.object_)
    def __call__(self, a, b, *pargs, **nargs):
        import numpy as np
        # Concatenate the columns of a and b so that one index space refers to any column
        self.ab = np.concatenate((a, b), axis = 1)
        self._grow(self.ab.shape[1])
        ix = np.searchsorted(self.ci[:a.shape[1]], self.ci[a.shape[1] : a.shape[1] + b.shape[1]], *pargs, **nargs)
        return ix
    def _grow(self, to):
        # Lazily grow the cache of _CmpIx wrappers up to the next power of two >= to
        import numpy as np
        if self.ci.size >= to:
            return
        import math
        to = 1 << math.ceil(math.log(to) / math.log(2))
        self.ci = np.concatenate((self.ci, [self._CmpIx(self.type_, self, i) for i in range(self.ci.size, to)]))

class SearchSorted2DNumba:
    @classmethod
    def do(cls, a, v, side = 'left', *, vsorted = False, numba_ = True):
        import numpy as np

        if not hasattr(cls, '_ido_numba'):
            def _ido_regular(a, b, vsorted, lrt):
                nk, na, nb = a.shape[0], a.shape[1], b.shape[1]
                res = np.zeros((2, nb), dtype = np.int64)
                max_depth = 0
                if nb == 0:
                    return res, max_depth
                #lb, le, rb, re = 0, 0, 0, 0
                lrb, lre = 0, 0
                
                if vsorted:
                    brngs = np.zeros((nb, 6), dtype = np.int64)
                    brngs[0, :4] = (-1, 0, nb >> 1, nb)
                    i, j, size = 0, 1, 1
                    while i < j:
                        for k in range(i, j):
                            cbrng = brngs[k]
                            bp, bb, bm, be = cbrng[:4]
                            if bb < bm:
                                brngs[size, :4] = (k, bb, (bb + bm) >> 1, bm)
                                size += 1
                            bmp1 = bm + 1
                            if bmp1 < be:
                                brngs[size, :4] = (k, bmp1, (bmp1 + be) >> 1, be)
                                size += 1
                        i, j = j, size
                    assert size == nb
                    brngs[:, 4:] = -1

                for ibc in range(nb):
                    if not vsorted:
                        ib, lrb, lre = ibc, 0, na
                    else:
                        ibpi, ib = int(brngs[ibc, 0]), int(brngs[ibc, 2])
                        if ibpi == -1:
                            lrb, lre = 0, na
                        else:
                            ibp = int(brngs[ibpi, 2])
                            if ib < ibp:
                                lrb, lre = int(brngs[ibpi, 4]), int(res[1, ibp])
                            else:
                                lrb, lre = int(res[0, ibp]), int(brngs[ibpi, 5])
                        brngs[ibc, 4 : 6] = (lrb, lre)
                        assert lrb != -1 and lre != -1
                        
                    for ik in range(nk):
                        if lrb >= lre:
                            if ik > max_depth:
                                max_depth = ik
                            break

                        bv = b[ik, ib]
                        
                        # Binary searches
                        
                        if nk != 1 or lrt == 2:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                av = a[ik, cm]
                                if av < bv:
                                    cb = cm + 1
                                elif bv < av:
                                    ce = cm
                                else:
                                    break
                            lrb, lre = cb, ce
                                
                        if nk != 1 or lrt >= 1:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                if not (bv < a[ik, cm]):
                                    cb = cm + 1
                                else:
                                    ce = cm
                            #rb, re = cb, ce
                            lre = ce
                                
                        if nk != 1 or lrt == 0 or lrt == 2:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                if a[ik, cm] < bv:
                                    cb = cm + 1
                                else:
                                    ce = cm
                            #lb, le = cb, ce
                            lrb = cb
                            
                        #lrb, lre = lb, re
                            
                    res[:, ib] = (lrb, lre)
                    
                return res, max_depth

            cls._ido_regular = _ido_regular
            
            import numba
            cls._ido_numba = numba.jit(nopython = True, nogil = True, cache = True)(cls._ido_regular)
            
        assert side in ['left', 'right', 'left_right'], side
        a, v = np.array(a), np.array(v)
        assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0], (a.shape, v.shape)
        res, max_depth = (cls._ido_numba if numba_ else cls._ido_regular)(
            a, v, vsorted, {'left': 0, 'right': 1, 'left_right': 2}[side],
        )
        return res[0] if side == 'left' else res[1] if side == 'right' else res

def Test():
    import time
    import numpy as np
    np.random.seed(0)
    
    def round_float_fixed_str(x, n = 0):
        if type(x) is int:
            return str(x)
        s = str(round(float(x), n))
        if n > 0:
            s += '0' * (n - (len(s) - 1 - s.rfind('.')))
        return s

    def to_tuples(x):
        r = np.empty([x.shape[1]], dtype = np.object_)
        r[:] = [tuple(e) for e in x.T]
        return r
    
    searchsorted2d = {
        'py.zip': SearchSorted2D('py.zip'),
        'np.nonzero': SearchSorted2D('np.nonzero'),
        'np.lexsort': SearchSorted2D('np.lexsort'),
        'cmp_numba': SearchSorted2D('cmp_numba'),
    }
    
    for iklen, klen in enumerate([1, 1, 2, 5, 10, 20, 50, 100, 200]):
        times = {}
        for side in ['left', 'right']:
            a = np.zeros((klen, 0), dtype = np.int64)
            tac = to_tuples(a)

            for itest in range((15, 100)[iklen == 0]):
                b = np.random.randint(0, (3, 100000)[iklen == 0], (klen, np.random.randint(1, (1000, 2000)[iklen == 0])), dtype = np.int64)
                b = b[:, np.lexsort(b[::-1])]
                
                if iklen == 0:
                    assert klen == 1, klen
                    ts = time.time()
                    ix1 = np.searchsorted(a[0], b[0], side = side)
                    te = time.time()
                    times['np.searchsorted'] = times.get('np.searchsorted', 0.) + te - ts
                    
                for cached in [False, True]:
                    ts = time.time()
                    tb = to_tuples(b)
                    ta = tac if cached else to_tuples(a)
                    ix1 = np.searchsorted(ta, tb, side = side)
                    if not cached:
                        ix0 = ix1
                    tac = np.insert(tac, ix0, tb) if cached else tac
                    te = time.time()
                    timesk = f'py.tuples{("", "_cached")[cached]}'
                    times[timesk] = times.get(timesk, 0.) + te - ts

                for type_ in searchsorted2d.keys():
                    if iklen == 0 and type_ in ['np.nonzero', 'np.lexsort']:
                        continue
                    ss = searchsorted2d[type_]
                    try:
                        ts = time.time()
                        ix1 = ss(a, b, side = side)
                        te = time.time()
                        times[type_] = times.get(type_, 0.) + te - ts
                        assert np.array_equal(ix0, ix1)
                    except Exception:
                        times[type_ + '!failed'] = 0.

                for numba_ in [False, True]:
                    for vsorted in [False, True]:
                        if numba_:
                            # Heat-up/pre-compile numba
                            SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                        
                        ts = time.time()
                        ix1 = SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                        te = time.time()
                        timesk = f'numba{("_py", "_nm")[numba_]}{("", "_sorted")[vsorted]}'
                        times[timesk] = times.get(timesk, 0.) + te - ts
                        assert np.array_equal(ix0, ix1)


                # View-1D methods suggested by @MadPhysicist
                if False: # Disabled because the view1d methods work only sometimes
                    aT, bT = np.copy(a.T), np.copy(b.T)
                    assert aT.ndim == 2 and bT.ndim == 2 and aT.shape[1] == klen and bT.shape[1] == klen, (aT.shape, bT.shape, klen)
                    
                    for ty in ['if', 'cf']:
                        try:
                            dt = np.dtype({'if': [('', b.dtype)] * klen, 'cf': [('row', b.dtype, klen)]}[ty])
                            ts = time.time()
                            va = np.ndarray(aT.shape[:1], dtype = dt, buffer = aT)
                            vb = np.ndarray(bT.shape[:1], dtype = dt, buffer = bT)
                            ix1 = np.searchsorted(va, vb, side = side)
                            te = time.time()
                            assert np.array_equal(ix0, ix1), (ix0.shape, ix1.shape, ix0[:20], ix1[:20])
                            times[f'view1d_{ty}'] = times.get(f'view1d_{ty}', 0.) + te - ts
                        except Exception:
                            raise
                
                a = np.insert(a, ix0, b, axis = 1)
            
        stimes = ([f'key_len: {str(klen).rjust(3)}'] +
            [f'{k}: {round_float_fixed_str(v, 4).rjust(7)}' for k, v in times.items()])
        nlines = 4
        print('-' * 50 + '\n' + ('', '!LARGE!:\n')[iklen == 0], end = '')
        for i in range(nlines):
            print(',  '.join(stimes[len(stimes) * i // nlines : len(stimes) * (i + 1) // nlines]), flush = True)
            
Test()

Output:

--------------------------------------------------
!LARGE!:
key_len:   1,  np.searchsorted:  0.0250
py.tuples_cached:  3.3113,  py.tuples: 30.5263,  py.zip: 40.9785
cmp_numba: 25.7826,  numba_py:  3.6673
numba_py_sorted:  6.8926,  numba_nm:  0.0466,  numba_nm_sorted:  0.0505
--------------------------------------------------
key_len:   1,  py.tuples_cached:  0.1371
py.tuples:  0.4698,  py.zip:  1.2005,  np.nonzero:  4.7827
np.lexsort:  4.4672,  cmp_numba:  1.0644,  numba_py:  0.2748
numba_py_sorted:  0.5699,  numba_nm:  0.0005,  numba_nm_sorted:  0.0020
--------------------------------------------------
key_len:   2,  py.tuples_cached:  0.1131
py.tuples:  0.3643,  py.zip:  1.0670,  np.nonzero:  4.5199
np.lexsort:  3.4595,  cmp_numba:  0.8582,  numba_py:  0.4958
numba_py_sorted:  0.6454,  numba_nm:  0.0025,  numba_nm_sorted:  0.0025
--------------------------------------------------
key_len:   5,  py.tuples_cached:  0.1876
py.tuples:  0.4493,  py.zip:  1.6342,  np.nonzero:  5.5168
np.lexsort:  4.6086,  cmp_numba:  1.0939,  numba_py:  1.0607
numba_py_sorted:  0.9737,  numba_nm:  0.0050,  numba_nm_sorted:  0.0065
--------------------------------------------------
key_len:  10,  py.tuples_cached:  0.6017
py.tuples:  1.2275,  py.zip:  3.5276,  np.nonzero: 13.5460
np.lexsort: 12.4183,  cmp_numba:  2.5404,  numba_py:  2.8334
numba_py_sorted:  2.3991,  numba_nm:  0.0165,  numba_nm_sorted:  0.0155
--------------------------------------------------
key_len:  20,  py.tuples_cached:  0.8316
py.tuples:  1.3759,  py.zip:  3.4238,  np.nonzero: 13.7834
np.lexsort: 16.2164,  cmp_numba:  2.4483,  numba_py:  2.6405
numba_py_sorted:  2.2226,  numba_nm:  0.0170,  numba_nm_sorted:  0.0160
--------------------------------------------------
key_len:  50,  py.tuples_cached:  1.0443
py.tuples:  1.4085,  py.zip:  2.2475,  np.nonzero:  9.1673
np.lexsort: 19.5266,  cmp_numba:  1.6181,  numba_py:  1.7731
numba_py_sorted:  1.4637,  numba_nm:  0.0415,  numba_nm_sorted:  0.0405
--------------------------------------------------
key_len: 100,  py.tuples_cached:  2.0136
py.tuples:  2.5380,  py.zip:  2.2279,  np.nonzero:  9.2929
np.lexsort: 33.9505,  cmp_numba:  1.5722,  numba_py:  1.7158
numba_py_sorted:  1.4208,  numba_nm:  0.0871,  numba_nm_sorted:  0.0851
--------------------------------------------------
key_len: 200,  py.tuples_cached:  3.5945
py.tuples:  4.1847,  py.zip:  2.3553,  np.nonzero: 11.3781
np.lexsort: 66.0104,  cmp_numba:  1.8153,  numba_py:  1.9449
numba_py_sorted:  1.6463,  numba_nm:  0.1661,  numba_nm_sorted:  0.1651

From the timings it can be seen that the numba_nm implementation is the fastest; it is 15-100x faster than the next fastest (py.zip or py.tuples_cached). For the 1-D case it has speed comparable to the standard np.searchsorted (about 1.85x slower). Also, the _sorted flavour (i.e. using the information that the array being inserted is already sorted) does not seem to improve things.

The cmp_numba machine-code compiled method appears to be about 1.5x faster on average than py.zip, which runs the same algorithm in pure python. Because the average maximum equal-key depth is around 15-18, numba does not gain much of a speedup here. If the depth were in the hundreds, the numba code would probably see a big speedup.

The py.tuples_cached strategy is faster than py.zip for the cases of key length <= 100.

Also, np.lexsort actually appears to be very slow; either it is not optimized for the case of just two columns, or it spends time on preprocessing such as splitting rows into lists, or it does a non-lazy lexicographic comparison. The last case is probably the real reason, because lexsort slows down as the key length grows.

The np.nonzero strategy is also non-lazy, hence it also works slowly and slows down as the key length grows (but not as fast as np.lexsort does).

The timings above may not be precise, because my CPU randomly lowers its core frequency 2-2.3x whenever it overheats, and it overheats often because it is a powerful CPU inside a laptop.

There are a couple of things here that can help you: (1) you can sort and search structured arrays, and (2) if you have finite collections that can be mapped to integers, you can use that to your advantage.

1-D view

Let's say you have an array of strings that you want to insert into:

data = np.array([['a', '1'], ['a', 'z'], ['b', 'a']], dtype=object)

Since arrays are never ragged, you can construct a dtype that is the size of a row:

dt = np.dtype([('', data.dtype)] * data.shape[1])

Using my shamelessly plugged answer here, you can now view the original 2-D array as 1-D:

view = np.ndarray(data.shape[:1], dtype=dt, buffer=data)

Searching can now be done in a completely straightforward manner:

key = np.array([('a', 'a')], dtype=dt)
index = np.searchsorted(view, key)

You can even find the insertion index of incomplete elements by using an appropriate minimal value. For strings this would be ''.
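For example, a minimal sketch reusing view and dt from above, with a hypothetical prefix 'b': padding the missing trailing field with '' gives the leftmost position of all rows starting with 'b':

partial = np.array([('b', '')], dtype=dt)  # '' plays the role of the minimal string
index = np.searchsorted(view, partial)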

Faster comparisons

You may get better mileage out of the comparisons if you don't have to check every field of the dtype. You can make a similar dtype with a single homogeneous field:

dt2 = np.dtype([('row', data.dtype, data.shape[1])])

Constructing the view is the same as before:

view = np.ndarray(data.shape[:1], dtype=dt2, buffer=data)

The key is made a bit differently this time (another plug):

key = np.array([(['a', 'a'],)], dtype=dt2)

The sort order this method imposes on the objects is not correct. I am leaving this here as a reference in case the linked question gets fixed. Also, it is still quite useful for sorting integers.

Integer mapping

If the number of objects you are searching through is finite, it is easier to map them to integers:

idata = np.empty(data.shape, dtype=int)
keys = [None] * data.shape[1]     # Map index to key per column
indices = [None] * data.shape[1]  # Map key to index per column
for i in range(data.shape[1]):
    keys[i], idata[:, i] = np.unique(data[:, i], return_inverse=True)
    indices[i] = {k: i for i, k in enumerate(keys[i])}  # Assumes hashable objects

idt = np.dtype([('row', idata.dtype, idata.shape[1])])
view = idata.view(idt).ravel()

This only works if data actually contains all possible keys in each column. Otherwise, you will have to obtain the forward and reverse mappings by other means. Once that is established, setting up a key is much simpler and only requires indices:

key = np.array([index[k] for index, k in zip(indices, ['a', 'a'])])
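A minimal usage sketch under the assumptions above (the view, idt and key built earlier): the plain integer key has to be viewed with the same compound dtype idt before it can be compared against view:

# key is a 1-D integer array of length data.shape[1]; reinterpret it as one
# structured element so it matches the elements of view.
index = np.searchsorted(view, key.view(idt))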

Further improvements

If the number of categories you have is eight or fewer, and each category has 256 or fewer elements, you can construct an even better hash by packing everything into a single np.uint64 element or so.

k = math.ceil(math.log(data.shape[1], 2))  # math.log provides base directly
assert 0 < k <= 64
idata = np.empty(data.shape[:1] + (k,), dtype=np.uint8)
...
idata = idata.view(f'>u{k}').ravel()

The key is made in a similar way:

key = np.array([index[k] for index, k in zip(indices, ['a', 'a'])]).view(f'>u{k}')

Timings

I have timed the methods shown here (not the ones from the other answers) using randomly shuffled strings. The key timing parameters are:

  • M: number of rows: 10**{2, 3, 4, 5}
  • N: number of columns: 2**{3, 4, 5, 6}
  • K: number of elements to insert: 1, 10, M // 10
  • Method: individual_fields, combined_fields, int_mapping, int_packing. The functions are shown below.

For the last two methods, I assume that you would pre-convert the data into the mapped dtype, but not the search keys. I therefore pass in the converted data, but time the conversion of the keys.

import numpy as np
from math import ceil, log

def individual_fields(data, keys):
    dt = [('', data.dtype)] * data.shape[1]
    dview = np.ndarray(data.shape[:1], dtype=dt, buffer=data)
    kview = np.ndarray(keys.shape[:1], dtype=dt, buffer=keys)
    return np.searchsorted(dview, kview)

def combined_fields(data, keys):
    dt = [('row', data.dtype, data.shape[1])]
    dview = np.ndarray(data.shape[:1], dtype=dt, buffer=data)
    kview = np.ndarray(keys.shape[:1], dtype=dt, buffer=keys)
    return np.searchsorted(dview, kview)

def int_mapping(idata, keys, indices):
    idt = np.dtype([('row', idata.dtype, idata.shape[1])])
    dview = idata.view(idt).ravel()
    kview = np.empty(keys.shape[0], dtype=idt)
    for i, (index, key) in enumerate(zip(indices, keys.T)):
        kview['row'][:, i] = [index[k] for k in key]
    return np.searchsorted(dview, kview)

def int_packing(idata, keys, indices):
    idt = f'>u{idata.shape[1]}'
    dview = idata.view(idt).ravel()
    kview = np.empty(keys.shape, dtype=np.uint8)
    for i, (index, key) in enumerate(zip(indices, keys.T)):
        kview[:, i] = [index[k] for k in key]
    kview = kview.view(idt).ravel()
    return np.searchsorted(dview, kview)

Timing code:

from math import ceil, log
from string import ascii_lowercase
from timeit import Timer

def time(m, n, k, fn, *args):
    t = Timer(lambda: fn(*args))
    s = t.autorange()[0]
    print(f'M={m}; N={n}; K={k} {fn.__name__}: {min(t.repeat(5, s)) / s}')

selection = np.array(list(ascii_lowercase), dtype=object)
for lM in range(2, 6):
    M = 10**lM
    for lN in range(3, 6):
        N = 2**lN
        data = np.random.choice(selection, size=(M, N))
        np.ndarray(data.shape[0], dtype=[('', data.dtype)] * data.shape[1], buffer=data).sort()
        idata = np.array([[ord(a) - ord('a') for a in row] for row in data], dtype=np.uint8)
        ikeys = [selection] * data.shape[1]
        indices = [{k: i for i, k in enumerate(selection)}] * data.shape[1]
        for K in (1, 10, M // 10):
            key = np.random.choice(selection, size=(K, N))
            time(M, N, K, individual_fields, data, key)
            time(M, N, K, combined_fields, data, key)
            time(M, N, K, int_mapping, idata, key, indices)
            if N <= 8:
                time(M, N, K, int_packing, idata, key, indices)

Results:

M=100 (unit = us)

   |                           K                           |
   +---------------------------+---------------------------+
N  |             1             |            10             |
   +------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+
 8 | 25.9 | 18.6 | 52.6 | 48.2 | 35.8 | 22.7 | 76.3 | 68.2 | 
16 | 40.1 | 19.0 | 87.6 |  --  | 51.1 | 22.8 | 130. |  --  |
32 | 68.3 | 18.7 | 157. |  --  | 79.1 | 22.4 | 236. |  --  |
64 | 125. | 18.7 | 290. |  --  | 135. | 22.4 | 447. |  --  |
---+------+------+------+------+------+------+------+------+

M=1000 (unit = us)

   |                                         K                                         |
   +---------------------------+---------------------------+---------------------------+
N  |             1             |            10             |            100            |
   +------+------+------+------+------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+------+------+------+------+
 8 | 26.9 | 19.1 | 55.0 | 55.0 | 44.8 | 25.1 | 79.2 | 75.0 | 218. | 74.4 | 305. | 250. |
16 | 41.0 | 19.2 | 90.5 |  --  | 59.3 | 24.6 | 134. |  --  | 244. | 79.0 | 524. |  --  | 
32 | 68.5 | 19.0 | 159. |  --  | 87.4 | 24.7 | 241. |  --  | 271. | 80.5 | 984. |  --  |
64 | 128. | 19.7 | 312. |  --  | 168. | 26.0 | 549. |  --  | 396. | 7.78 | 2.0k |  --  |
---+------+------+------+------+------+------+------+------+------+------+------+------+

M=10K (unit = us)

   |                                         K                                         |
   +---------------------------+---------------------------+---------------------------+
N  |             1             |            10             |           1000            |
   +------+------+------+------+------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+------+------+------+------+
 8 | 28.8 | 19.5 | 54.5 | 107. | 57.0 | 27.2 | 90.5 | 128. | 3.2k | 762. | 2.7k | 2.1k |
16 | 42.5 | 19.6 | 90.4 |  --  | 73.0 | 27.2 | 140. |  --  | 3.3k | 752. | 4.6k |  --  |
32 | 73.0 | 19.7 | 164. |  --  | 104. | 26.7 | 246. |  --  | 3.4k | 803. | 8.6k |  --  |
64 | 135. | 19.8 | 302. |  --  | 162. | 26.1 | 466. |  --  | 3.7k | 791. | 17.k |  --  |
---+------+------+------+------+------+------+------+------+------+------+------+------+

individual_fields (IF) is generally the fastest method that works. Its complexity grows in proportion to the number of columns. Unfortunately, combined_fields (CF) does not work for object arrays. Otherwise, it would not only be the fastest method, but also one whose complexity does not grow with an increasing number of columns.

All the techniques that I thought would be faster are not, because mapping python objects to keys is slow (the actual lookup on a packed int array, for example, is much faster than on a structured array).

References

To get this code to work, I had to ask the following additional question:

  • View object array under different dtype