仅当行满足 Pandas 的特定条件时才计算每组的滚动函数

Compute rolling function per group only if rows meet certain conditions with Pandas

假设我有以下数据框(这里是简化版):

| id | date | cond | ret |
| -- | ---- | ---- | --- |
| a  | t1   | 1    | n1  |
| a  | t2   | 0    | n2  |
| a  | t3   | 0    | n3  |
| a  | t4   | 1    | n4  |
| a  | t5   | 0    | n5  |
| a  | t6   | 0    | n6  |
| a  | t7   | 0    | n7  |
| a  | t8   | 1    | n8  |
| a  | t9   | 0    | n9  |
| a  | t10  | 1    | n10 |
| b  | t1   | 1    | n11 |
| b  | t2   | 0    | n12 |
| b  | t3   | 0    | n13 |
| b  | t4   | 1    | n14 |
| b  | t5   | 0    | n15 |
| b  | t6   | 0    | n16 |
| b  | t7   | 1    | n17 |
| b  | t8   | 0    | n18 |
| b  | t9   | 1    | n19 |
| b  | t10  | 0    | n20 |

我有兴趣通过 id 计算任意 window 的滚动标准偏差(本例中假设为 3)。但是,我只想为满足 cond==1.

的那些行计算它

我想要的输出是这样的:

| id | date | cond | ret |      r_std      |
| -- | ---- | ---- | --- | --------------- |
| a  | t1   | 1    | n1  | nan             |
| a  | t2   | 0    | n2  | nan             |
| a  | t3   | 0    | n3  | nan             |
| a  | t4   | 1    | n4  | std(n2,n3,n4)   |
| a  | t5   | 0    | n5  | nan             |
| a  | t6   | 0    | n6  | nan             |
| a  | t7   | 0    | n7  | nan             |
| a  | t8   | 1    | n8  | std(n6,n7,n8)   |
| a  | t9   | 0    | n9  | nan             |
| a  | t10  | 1    | n10 | std(n8,n9,n10)  |
| b  | t1   | 1    | n11 | nan             |
| b  | t2   | 0    | n12 | nan             |
| b  | t3   | 0    | n13 | nan             |
| b  | t4   | 1    | n14 | std(n12,n13,n14)|
| b  | t5   | 0    | n15 | nan             |
| b  | t6   | 0    | n16 | nan             |
| b  | t7   | 1    | n17 | std(n15,n16,n17)|
| b  | t8   | 0    | n18 | nan             |
| b  | t9   | 1    | n19 | std(n17,n18,n19)|
| b  | t10  | 0    | n20 | nan             |

或者这个:

| id | date | cond | ret |      r_std      |
| -- | ---- | ---- | --- | --------------- |
| a  | t4   | 1    | n4  | std(n2,n3,n4)   |
| a  | t8   | 1    | n8  | std(n6,n7,n8)   |
| a  | t10  | 1    | n10 | std(n8,n9,n10)  |
| b  | t4   | 1    | n14 | std(n12,n13,n14)|
| b  | t7   | 1    | n17 | std(n15,n16,n17)|
| b  | t9   | 1    | n19 | std(n17,n18,n19)|

我的第一次尝试是:

df.loc[df['cond']==1, 'r_std'] = df.loc[df['cond']==1].groupby('id')['ret']].apply(lambda x : x.rolling(window=3).std())

但这不起作用,因为它仅计算由 .loc 确定的数据集切片的滚动标准偏差。这是我从上面的代码中得到的:

| id | date | cond | ret |      r_std      |
| -- | ---- | ---- | --- | --------------- |
| a  | t1   | 1    | n1  | nan             |
| a  | t2   | 0    | n2  | nan             |
| a  | t3   | 0    | n3  | nan             |
| a  | t4   | 1    | n4  | nan             |
| a  | t5   | 0    | n5  | nan             |
| a  | t6   | 0    | n6  | nan             |
| a  | t7   | 0    | n7  | nan             |
| a  | t8   | 1    | n8  | std(n1,n4,n8)   |
| a  | t9   | 0    | n9  | nan             |
| a  | t10  | 1    | n10 | std(n4,n8,n10)  |
| b  | t1   | 1    | n11 | nan             |
| b  | t2   | 0    | n12 | nan             |
| b  | t3   | 0    | n13 | nan             |
| b  | t4   | 1    | n14 | nan             |
| b  | t5   | 0    | n15 | nan             |
| b  | t6   | 0    | n16 | nan             |
| b  | t7   | 1    | n17 | std(n11,n14,n17)|
| b  | t8   | 0    | n18 | nan             |
| b  | t9   | 1    | n19 | std(n14,n17,n19)|
| b  | t10  | 0    | n20 | nan             |

我也试过:

df['r_std'] = df.groupby('id')[['cond', 'ret']].apply(lambda x : x[1].rolling(window=3).std() if x[0]==1 else np.nan)

但这会引发错误。

我知道我可以简单地计算每一行的滚动标准差,然后 select 只计算我感兴趣的行,但这是一个非常大的高频数据,效率非常低.

谢谢。

试试 reset_indexwhere

df['new'] = df.groupby('id').ret.rolling(3).std().reset_index(level=0,drop=True).where(df.cond==1)
df
Out[227]: 
   id date  cond  ret  new
0   a   t1     1    0  NaN
1   a   t2     0    1  NaN
2   a   t3     0    2  NaN
3   a   t4     1    3  1.0
4   a   t5     0    4  NaN
5   a   t6     0    5  NaN
6   a   t7     0    6  NaN
7   a   t8     1    7  1.0
8   a   t9     0    8  NaN
9   a  t10     1    9  1.0
10  b   t1     1   10  NaN
11  b   t2     0   11  NaN
12  b   t3     0   12  NaN
13  b   t4     1   13  1.0
14  b   t5     0   14  NaN
15  b   t6     0   15  NaN
16  b   t7     1   16  1.0
17  b   t8     0   17  NaN
18  b   t9     1   18  1.0
19  b  t10     0   19  NaN