pandas 更优化的解决方案应用行

A more optimized solution to pandas apply row-wise

我有这段代码可以对 DataFrame 进行一些分析。 both_profitableTrue 当且仅当该行中的 long_profitableshort_profitable 都是 True。但是,DataFrame 非常大,在 axis=1 上使用 pandas apply 比我想要的更费力。

output["long_profitable"] = (
    df[[c for c in df.columns if "long_profit" in c]].ge(target).any(axis=1)
)
output["short_profitable"] = (
    df[[c for c in df.columns if "short_profit" in c]].ge(target).any(axis=1)
)
output["both_profitable"] = output.apply(
    lambda x: True if x["long_profitable"] and x["short_profitable"] else False,
    axis=1,
)

是否有 simpler/more 优化的方法来实现同样的目标?

您应该在列上使用 eq 方法:

output["both_profitable"] = output["long_profitable"].eq(output["short_profitable"])

或者由于两列都是布尔值,您可以使用按位 & 运算符:

output["both_profitable"]  = output["long_profitable"] & output["short_profitable"]

此外,您可以使用 str.contains + loc,而不是对 df 的 select 列的列表理解:

output["long_profitable"] = df.loc[:, df.columns.str.contains('long_profit')].ge(target).any(axis=1)
output["short_profitable"] = df.loc[:, df.columns.str.contains('short_profit')].ge(target).any(axis=1)

both_profitable is True if and only if both long_profitable and short_profitable in that row are True

换句话说,both_profitable是两列的布尔AND运算结果。

这可以通过多种方式实现:

output['long_profitable'] & output['short_profitable']
# for any number of boolean columns, all of which we want to AND
cols = ['long_profitable', 'short_profitable']
output[cols].all(axis=1)
# same logic, using prod() -- this is just for fun; use all() instead
output[cols].prod(axis=1).astype(bool)

当然,您可以将以上任何一项分配给新列:

output_modified = output.assign(both_profitable=...)

注意:如果您有 AND-ing 多列,则第二和第三种形式特别有用。

时机

n = 10_000_000
np.random.seed(0)
output = pd.DataFrame({
    'long_profitable': np.random.randint(0, 2, n, dtype=bool),
    'short_profitable': np.random.randint(0, 2, n, dtype=bool),
})

%timeit output['long_profitable'] & output['short_profitable']
# 4.52 ms ± 41.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit output[cols].all(axis=1)
# 18.6 ms ± 53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit output[cols].prod(axis=1).astype(bool)
# 71.6 ms ± 375 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)