DataFrame groupby() on MultiIndex then apply on multiple columns leads to broadcasting problems
Here is the setup:
import numpy as np
import pandas as pd

arrays = [["2010-01-01", "2010-01-01", "2010-01-02", "2010-01-02", "2010-01-03", "2010-01-03"],
          ["MSFT", "AAPL", "MSFT", "AAPL", "MSFT", "AAPL"]]
tuples = list(zip(*arrays))
index = pd.MultiIndex.from_tuples(tuples, names=["date", "symbol"])
df = pd.DataFrame(data=np.random.randn(6, 4), index=index, columns=["high", "low", "open", "close"])

def fn_sum(close, high, low):
    return close + high + low

def fn_plus(close):
    return close + 1
The DataFrame looks like this:
                       high       low      open     close
date       symbol
2010-01-01 MSFT    1.144042  0.889603 -0.193715  1.005927
           AAPL    0.433530 -0.291510  1.420505  0.326206
2010-01-02 MSFT   -1.509419 -0.273476 -0.620735 -0.205946
           AAPL    0.454401 -0.085008  0.686485  1.309894
2010-01-03 MSFT    1.487588 -0.777500 -0.218993 -1.242664
           AAPL   -0.456024 -0.819463 -2.224953  1.263124
I want to apply technical-analysis functions to every symbol using groupby() and apply(), like this:
df["1"] = df.groupby(level="symbol").apply(lambda x: fn_sum(x["close"], x["high"], x["low"]))
This raises a broadcasting error:
ValueError: operands could not be broadcast together with shapes (6,2) (3,) (6,2)
Doing the same thing on a single column works, though:
df["2"] = df.groupby(level="symbol").close.apply(lambda x: fn_plus(x))
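Question:
How can I use apply on multiple columns and combine the result back into the DataFrame without running into the broadcasting problem?
I would also appreciate any better way of working with a MultiIndex DataFrame like the one above.
More context: I want to use the technical-analysis functions from the TA-lib package. See: https://mrjbq7.github.io/ta-lib/func_groups/volatility_indicators.html
The functions look like this (for example):
ATR(high, low, close[, timeperiod=?])
Average True Range (Volatility Indicators)
Inputs: prices: ['high', 'low', 'close']
Parameters: timeperiod: 14
Outputs: real
I get the same broadcasting error as in the contrived example above.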
If you need to pass multiple columns to the function, use DataFrame.join or DataFrame.assign:
s = (df.groupby(level="symbol", group_keys=False)
       .apply(lambda x: fn_sum(x["close"], x["high"], x["low"])))
df = df.join(s.rename('new'))
# alternative
# df = df.assign(new=s)
print(df)
                       high       low      open     close       new
date       symbol
2010-01-01 MSFT   -1.085631  0.997345  0.282978 -1.506295 -1.594580
           AAPL   -0.578600  1.651437 -2.426679 -0.428913  0.643924
2010-01-02 MSFT    1.265936 -0.866740 -0.678886 -0.094709  0.304487
           AAPL    1.491390 -0.638902 -0.443982 -0.434351  0.418136
2010-01-03 MSFT    2.205930  2.186786  1.004054  0.386186  4.778903
           AAPL    0.737369  1.490732 -0.935834  1.175829  3.403930
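To see that the joined column really lines up row by row, here is a minimal self-contained check. The seeded generator and the variable name `out` are my own assumptions for reproducibility, not part of the original post. With group_keys=False the result of apply keeps the original (date, symbol) index, so join can align it by label.

```python
import numpy as np
import pandas as pd

# Seeded data so the run is reproducible (the seed is an assumption).
rng = np.random.default_rng(0)
index = pd.MultiIndex.from_product(
    [["2010-01-01", "2010-01-02", "2010-01-03"], ["MSFT", "AAPL"]],
    names=["date", "symbol"],
)
df = pd.DataFrame(rng.standard_normal((6, 4)), index=index,
                  columns=["high", "low", "open", "close"])

def fn_sum(close, high, low):
    return close + high + low

# group_keys=False: do not prepend the group key to the result index,
# so the Series is indexed by (date, symbol) just like df.
s = (df.groupby(level="symbol", group_keys=False)
       .apply(lambda x: fn_sum(x["close"], x["high"], x["low"])))

# join aligns on the index labels, not on row position.
out = df.join(s.rename("new"))

# Each row of "new" should equal close + high + low of the same row.
expected = df["close"] + df["high"] + df["low"]
print(np.allclose(out["new"], expected))
```

The rows of s may come back grouped by symbol rather than in the original order, which is why joining by index is safer than assigning raw values.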
If only one column is involved, use GroupBy.transform and select the column after groupby:
df['new1'] = df.groupby(level="symbol")['close'].transform(fn_plus)
print(df)
                       high       low      open     close      new1
date       symbol
2010-01-01 MSFT   -1.085631  0.997345  0.282978 -1.506295 -0.506295
           AAPL   -0.578600  1.651437 -2.426679 -0.428913  0.571087
2010-01-02 MSFT    1.265936 -0.866740 -0.678886 -0.094709  0.905291
           AAPL    1.491390 -0.638902 -0.443982 -0.434351  0.565649
2010-01-03 MSFT    2.205930  2.186786  1.004054  0.386186  1.386186
           AAPL    0.737369  1.490732 -0.935834  1.175829  2.175829
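For the TA-lib case from the question, the same join pattern should carry over if the indicator takes arrays and returns an array of the same length. The sketch below is only an illustration under that assumption: mean_true_range is a hypothetical stand-in I wrote so the example runs without TA-lib installed, and it is not the TA-lib ATR implementation. The helper per_symbol and the column name atr_like are also made up here.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # seed chosen only for reproducibility
index = pd.MultiIndex.from_product(
    [["2010-01-01", "2010-01-02", "2010-01-03"], ["MSFT", "AAPL"]],
    names=["date", "symbol"],
)
df = pd.DataFrame(rng.standard_normal((6, 4)), index=index,
                  columns=["high", "low", "open", "close"])

def mean_true_range(high, low, close, timeperiod=2):
    """Hypothetical stand-in for an array-in, array-out indicator.

    Not the TA-lib implementation: a plain rolling mean of the true range.
    """
    prev_close = np.concatenate(([np.nan], close[:-1]))
    # np.fmax ignores NaN, so the first row falls back to high - low.
    true_range = np.fmax(high - low,
                         np.fmax(np.abs(high - prev_close),
                                 np.abs(low - prev_close)))
    return pd.Series(true_range).rolling(timeperiod).mean().to_numpy()

def per_symbol(group):
    values = mean_true_range(group["high"].to_numpy(),
                             group["low"].to_numpy(),
                             group["close"].to_numpy())
    # Re-attach the group's own index so the result aligns with df.
    return pd.Series(values, index=group.index)

s = df.groupby(level="symbol", group_keys=False).apply(per_symbol)
out = df.join(s.rename("atr_like"))
print(out)
```

Wrapping the returned array in a Series with group.index is the important step: without it the values lose their (date, symbol) labels and cannot be aligned back onto df. With a real indicator you would swap the stand-in for the library call and keep the wrapper.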