Pandas idxmax() 获取行中最大值对应的列名

Pandas idxmax() get column name corresponding to max value in row

假设我有以下 DataFrame Q_df:

        (0, 0)  (0, 1)  (0, 2)  (1, 0)  (1, 1)  (1, 2)  (2, 0)  (2, 1)  (2, 2)
(0, 0)   0.000    0.00     0.0    0.64   0.000     0.0   0.512   0.000     0.0
(0, 1)   0.000    0.00     0.8    0.00   0.512     0.0   0.000   0.512     0.0
(0, 2)   0.000    0.64     0.0    0.00   0.000     0.8   0.000   0.000     1.0
(1, 0)   0.512    0.00     0.0    0.00   0.000     0.8   0.512   0.000     0.0
(1, 1)   0.000    0.64     0.0    0.00   0.000     0.0   0.000   0.512     0.0
(1, 2)   0.000    0.00     0.8    0.64   0.000     0.0   0.000   0.000     1.0
(2, 0)   0.512    0.00     0.0    0.64   0.000     0.0   0.000   0.512     0.0
(2, 1)   0.000    0.64     0.0    0.00   0.512     0.0   0.512   0.000     0.0
(2, 2)   0.000    0.00     0.8    0.00   0.000     0.8   0.000   0.000     0.0

使用以下代码生成:

import numpy as np
import pandas as pd

states = list(itertools.product(range(3), repeat=2))

Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])

Q_df = pd.DataFrame(index=states, columns=states, data=Q)

对于Q的每一行,我想得到该行中最大值对应的列名。如果我尝试

policy = Q_df.idxmax()

然后生成的系列如下所示:

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (0, 1)
(1, 0)    (0, 0)
(1, 1)    (0, 1)
(1, 2)    (0, 2)
(2, 0)    (0, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)

第一行看起来不错:第一行的最大元素是 0.64 并且出现在第 (1,0) 列中。第二个也是。但是,对于第三行,最大元素是 0.8 并且出现在列 (1,2) 中,所以我希望 policy 中的相应值是 (1,2),而不是 (0,1).

知道这里出了什么问题吗?

IIUC,可以在idxmax中使用axis=1:

policy = Q_df.idxmax(axis=1)

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (2, 2)
(1, 0)    (1, 2)
(1, 1)    (0, 1)
(1, 2)    (2, 2)
(2, 0)    (1, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)
dtype: object