Apply a lambda with a shift function in Python pandas where some null elements are to be replaced

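I am trying to do the following in a dataframe: if Period is not 1, change the value of the Attrition column by multiplying the Retention value in that row by the Attrition value in the previous row of the same group. My attempt is below: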

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5,'' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print(df)

gives me this output:

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

The following:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

print(df)

gives me this error:

df['Attrition'] = df.groupby(['Country','Week']).apply(lambda x: x.Attrition.shift(1)*x['Retention'] if x.Period != 1 else x.Attrition)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

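UPDATE: Complete compiled solution

Below is the complete working solution I was after. It basically uses Primer's answer, but adds a while loop that keeps running the lambda function on the dataframe column until there are no more NaNs.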

import pandas as pd
import numpy as np

data = {'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425', '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5, '' ,'' ,0.85 ,0.865,'' ,0.74 ,'','' ],
    'Retention': [0.95,0.85,0.94,0.85,0.97,0.93,0.97,0.93,0.94]}

df = pd.DataFrame(data, columns= ['Country', 'Week', 'Period', 'Attrition','Retention'])
print(df)

Output: starting DF

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2                 0.85
2      DE  201426       3                 0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2                 0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2                 0.93
8      FR  201426       3                 0.94

Solution:

# Replace empty strings with NaN
df['Attrition'] = df['Attrition'].replace('', np.nan)

# Count the NaNs in the column
ContainsNaN = df['Attrition'].isnull().sum()

# Run the loop while there are still NaNs in the column.
# Each pass fills one more Period per group from the row above it.
while ContainsNaN > 0:
    # index=x.index keeps each result aligned with its original row
    df['Attrition'] = df.groupby(['Country','Week'], group_keys=False).apply(lambda x: pd.Series(np.where((x.Period != 1), x.Attrition.shift() * x['Retention'], x.Attrition), index=x.index))
    ContainsNaN = df['Attrition'].isnull().sum()

print(df)

Output: result

  Country    Week  Period Attrition  Retention
0      DE  201426       1       0.5       0.95
1      DE  201426       2     0.425       0.85
2      DE  201426       3    0.3995       0.94
3      US  201426       1      0.85       0.85
4      US  201425       1     0.865       0.97
5      US  201425       2   0.80445       0.93
6      FR  201426       1      0.74       0.97
7      FR  201426       2    0.6882       0.93
8      FR  201426       3  0.646908       0.94

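First of all, your Attrition column mixes numeric data with empty strings '', which is generally not a good idea and should be fixed before attempting any calculations on this column:

import numpy as np

df.loc[df['Attrition'] == '', 'Attrition'] = np.nan
df['Attrition'] = df.Attrition.astype('float')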
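As an alternative sketch, and assuming a pandas version that provides pd.to_numeric, the same cleanup can be done in one step. The small Series here is a made-up sample that mirrors the Attrition column, not the full dataframe:

```python
import pandas as pd

# Hypothetical sample mirroring the question's Attrition column
attrition = pd.Series([0.5, '', '', 0.85], dtype=object)

# errors='coerce' turns anything that cannot be parsed as a number into NaN
cleaned = pd.to_numeric(attrition, errors='coerce')

print(cleaned.dtype)
print(cleaned.isna().sum())
```

This also catches other non-numeric leftovers (for example stray text), which may or may not be what you want.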

The error you are getting is caused by the condition inside your .apply: x.Period != 1 produces a boolean array:

0    False
1     True
2     True
3    False
4    False
5     True
6    False
7     True
8     True
Name: Period, dtype: bool

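The if ... else inside the lambda passed to .apply does not know how to handle this, because a multi-element boolean array has no single truth value (i.e. what should count as True in this case?).

You could consider numpy.where for this task: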
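The ambiguity can be reproduced without groupby at all. This is a minimal sketch on a made-up three-row Period column; the variable names are only for illustration:

```python
import pandas as pd

period = pd.Series([1, 2, 3], name='Period')
mask = period != 1  # element-wise comparison gives a boolean Series

try:
    if mask:  # Python needs a single True/False here
        pass
except ValueError as err:
    message = str(err)
    print(message)

# Reducing the mask to one value removes the ambiguity,
# but it then applies one decision to the whole group
print(mask.any(), mask.all())
```

Neither .any() nor .all() is what the question needs, since the choice has to be made row by row.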

import numpy as np
g = df.groupby(['Country','Week'], as_index=0, group_keys=0)
df['Attrition'] = g.apply(lambda x: pd.Series(np.where((x.Period != 1), x.Attrition.shift() * x['Retention'], x.Attrition)).fillna(method='ffill')).values
df

Yields:

  Country    Week  Period  Attrition  Retention
0      DE  201426       1      0.500       0.95
1      DE  201426       2      0.425       0.85
2      DE  201426       3      0.425       0.94
3      US  201426       1      0.740       0.85
4      US  201425       1      0.688       0.97
5      US  201425       2      0.688       0.93
6      FR  201426       1      0.865       0.97
7      FR  201426       2      0.804       0.93
8      FR  201426       3      0.850       0.94

Note that I have added the .fillna method, which fills NaN with the last observed value.
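If the goal is instead to compound Retention down each group, as the while loop in the question's update does, a loop-free sketch is possible with a grouped cumulative product. It assumes that rows are already ordered by Period within each (Country, Week) group, that the Period 1 row of every group holds a real Attrition value, and that the empty strings have already been converted to NaN. The names factor, running and base are just illustrative:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Country': ['DE', 'DE', 'DE', 'US', 'US', 'US', 'FR', 'FR', 'FR'],
    'Week': ['201426', '201426', '201426', '201426', '201425', '201425',
             '201426', '201426', '201426'],
    'Period': [1, 2, 3, 1, 1, 2, 1, 2, 3],
    'Attrition': [0.5, np.nan, np.nan, 0.85, 0.865, np.nan, 0.74, np.nan, np.nan],
    'Retention': [0.95, 0.85, 0.94, 0.85, 0.97, 0.93, 0.97, 0.93, 0.94],
})

keys = [df['Country'], df['Week']]

# Period 1 contributes a factor of 1; later periods contribute their Retention
factor = df['Retention'].where(df['Period'] != 1, 1.0)

# Running product of those factors inside each (Country, Week) group
running = factor.groupby(keys).cumprod()

# Carry the Period 1 Attrition forward inside each group
base = df['Attrition'].groupby(keys).ffill()

# Both pieces keep the original row index, so the assignment aligns by row
df['Attrition'] = base * running
print(df)
```

Because cumprod and ffill on a groupby return results on the original index, no .values is needed and the row order of the dataframe does not have to match the sorted group order.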