How to optimize this pandas iterable

I have the following method, in which I eliminate overlapping intervals in a dataframe based on a set of hierarchical rules:

import numpy as np
import pandas as pd

def disambiguate(arg):

    arg['length'] = (arg.end - arg.begin).abs()

    df = arg[['begin', 'end', 'note_id', 'score', 'length']].copy()

    data = []
    for row in df.itertuples():

        test = df[df['note_id']==row.note_id].copy()

        # get overlapping intervals:
        iix = pd.IntervalIndex.from_arrays(test.begin.apply(pd.to_numeric), test.end.apply(pd.to_numeric), closed='neither')
        span_range = pd.Interval(row.begin, row.end)
        fx = test[iix.overlaps(span_range)].copy()

        maxLength = fx['length'].max()
        minLength = fx['length'].min()
        maxScore = abs(float(fx['score'].max()))
        minScore = abs(float(fx['score'].min()))

        # filter out overlapping rows via hierarchy 
        if maxScore > minScore:
            fx = fx[fx['score'] == maxScore]

        elif maxLength > minLength:
            fx = fx[fx['length'] == maxLength]

        data.append(fx)

    out = pd.concat(data, axis=0)

    # randomly reindex to keep random row when dropping remaining duplicates: https://gist.github.com/cadrev/6b91985a1660f26c2742
    out.reset_index(inplace=True)
    out = out.reindex(np.random.permutation(out.index))

    return out.drop_duplicates(subset=['begin', 'end', 'note_id'])
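For context on the overlap test inside the loop: `IntervalIndex.overlaps` checks every interval in the index against a single `pd.Interval` and returns a boolean mask. A minimal self-contained illustration (with made-up endpoints, not my real data):

```python
import pandas as pd

# three open intervals, mirroring closed='neither' above
iix = pd.IntervalIndex.from_arrays([0, 10, 25], [9, 14, 37], closed='neither')

# a probe span; pd.Interval defaults to closed='right'
span = pd.Interval(8, 12)

# elementwise overlap test -> boolean array: [True, True, False]
mask = iix.overlaps(span)
```

The mask is then used to slice the per-note frame, exactly as `test[iix.overlaps(span_range)]` does above.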

This works fine, except that the dataframes I am iterating over each have more than 100k rows, so this takes a very long time to complete. I timed the various methods with %prun in Jupyter, and the call eating up the processing time appears to be series.py:3719(apply) ... Note: I tried using modin.pandas, but that caused more problems (I kept getting an error about Interval needing a left value smaller than right, which I couldn't figure out; I may file an issue on their GitHub).
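On that `apply` hotspot: `test.begin.apply(pd.to_numeric)` converts values one at a time through a Python-level loop, whereas calling `pd.to_numeric` on the whole Series (or converting the columns once, before the loop) does the same conversion in a single vectorized call. A tiny sketch of the difference:

```python
import pandas as pd

s = pd.Series(['1', '2', '3'])

# elementwise: pd.to_numeric is invoked once per value via a Python-level loop
via_apply = s.apply(pd.to_numeric)

# one vectorized call over the whole Series
via_vector = pd.to_numeric(s)

assert via_apply.equals(via_vector)
```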

I am looking for a way to optimize this, e.g. with vectorization, but honestly I have no idea how to translate it into a vectorized form.

Here is a sample of my data:

begin,end,note_id,score
0,9,0365,1
10,14,0365,1
25,37,0365,0.7
28,37,0365,1
38,42,0365,1
53,69,0365,0.7857142857142857
56,60,0365,1
56,69,0365,1
64,69,0365,1
83,86,0365,1
91,98,0365,0.8333333333333334
101,108,0365,1
101,127,0365,1
112,119,0365,1
112,127,0365,0.8571428571428571
120,127,0365,1
163,167,0365,1
196,203,0365,1
208,216,0365,1
208,223,0365,1
208,231,0365,1
208,240,0365,0.6896551724137931
217,223,0365,1
217,231,0365,1
224,231,0365,1
246,274,0365,0.7692307692307693
252,274,0365,1
263,274,0365,0.8888888888888888
296,316,0365,0.7222222222222222
301,307,0365,1
301,316,0365,1
301,330,0365,0.7307692307692307
301,336,0365,0.78125
308,316,0365,1
308,323,0365,1
308,330,0365,1
308,336,0365,1
317,323,0365,1
317,336,0365,1
324,330,0365,1
324,336,0365,1
361,418,0365,0.7368421052631579
370,404,0365,0.7111111111111111
370,418,0365,0.875
383,418,0365,0.8285714285714286
396,404,0365,1
396,418,0365,0.8095238095238095
405,418,0365,0.8333333333333334
432,453,0365,0.7647058823529411
438,453,0365,1
438,458,0365,0.7222222222222222

I think I know what the problem was: my filter on note_id was incorrect, so I was iterating over the entire dataframe.

It should be:

    cases = set(df['note_id'].tolist())

    for case in cases:
        test = df[df['note_id']==case].copy()
        for row in test.itertuples():

            # get overlapping intervals:
            iix = pd.IntervalIndex.from_arrays(test.begin, test.end, closed='neither')
            span_range = pd.Interval(row.begin, row.end)
            fx = test[iix.overlaps(span_range)].copy()

            maxLength = fx['length'].max()
            minLength = fx['length'].min()
            maxScore = abs(float(fx['score'].max()))
            minScore = abs(float(fx['score'].min()))

            if maxScore > minScore:
                fx = fx[fx['score'] == maxScore]

            elif maxLength > minLength:
                fx = fx[fx['length'] == maxLength]

            data.append(fx)

        out = pd.concat(data, axis=0)

Testing against a single note, it had run for over 16 minutes before I killed it when it was still iterating over the whole unfiltered dataframe. Now it takes 28 seconds!
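As a possible further step (a sketch under the same rules, not profiled on the full data): `groupby('note_id')` hands over each case slice directly, and the IntervalIndex then only needs to be built once per group instead of once per row. The function name `disambiguate_grouped` is mine:

```python
import pandas as pd

def disambiguate_grouped(arg):
    arg = arg.copy()
    arg['length'] = (arg['end'] - arg['begin']).abs()
    data = []
    # each note_id slice is visited exactly once, so the IntervalIndex
    # is built per group rather than per row
    for _, test in arg.groupby('note_id'):
        iix = pd.IntervalIndex.from_arrays(test['begin'], test['end'],
                                           closed='neither')
        for row in test.itertuples():
            fx = test[iix.overlaps(pd.Interval(row.begin, row.end))]
            # same hierarchy: prefer higher score, then longer span
            if fx['score'].max() > fx['score'].min():
                fx = fx[fx['score'] == fx['score'].max()]
            elif fx['length'].max() > fx['length'].min():
                fx = fx[fx['length'] == fx['length'].max()]
            data.append(fx)
    return pd.concat(data).drop_duplicates(subset=['begin', 'end', 'note_id'])
```

For brevity this sketch drops the random-reindex step before `drop_duplicates`, so it keeps the first occurrence of each duplicate rather than a random one.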