Numpy einsum behaving badly. What to look out for?

What is typically wrong when numpy einsum throws this error:

Traceback (most recent call last):
  File "rmse_iter.py", line 30, in <module>
    rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
TypeError: invalid data type for einsum

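The numpy array diffs is the result of subtracting two pandas dataframes and contains only numbers of type np.float32: no strings, NaNs, +/-infs, or any other such funny business. So what should I be looking for? Under what circumstances does einsum generally fail in this way?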

This is how I load and process the dataframe:

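import numpy as np
import pandas as pd

df = pd.read_pickle(fn)  # fn: path to the pickled dataframe
df.replace([np.inf, -np.inf], np.nan, inplace=True)  # treat +/-inf as missing
df.dropna(inplace=True)  # drop rows with missing values
a = df.values
diffs = a[:,2:27] - a[:,27:]  # columns 2..26 minus columns 27 onward
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)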

Apologies for the open-ended nature of the question, and thanks for introducing me to einsum magic.

Edit:

Here is my attempt at including the actual data in tabular form:

        rna     cnv     1_a     2_a     3_a     4_a     5_a     6_a     7_a     8_a     9_a     10_a    11_a    12_a    13_a    14_a    15_a    16_a    17_a    18_a    19_a    20_a    21_a    22_a    23_a    24_a    25_a    1_b     2_b     3_b     4_b     5_b     6_b     7_b     8_b     9_b     10_b    11_b    12_b    13_b    14_b    15_b    16_b    17_b    18_b    19_b    20_b    21_b    22_b    23_b    24_b    25_b
5641095 AP1G1   CCL8    3.588543653488159       10.119391441345215      32.92853546142578       6.307891368865967       -32.6164665222168       -34.94172286987305      -4.913632869720459
      -0.1798282265663147     -0.5144565105438232     12.70481014251709       -37.560791015625        39.83904266357422       32.92853546142578       -0.9303828477859497     -32.6164665222168       -8.661237716674805      31.074113845825195      -0.1798282265663147     -0.5144565105438232     -4.566867828369141      -2.5914463996887207     10.119391441345215      -12.007019996643066     6.307891368865967       -21.65423583984375      -8.217794418334961      2.9316258430480957      27.942243576049805      11.107816696166992      -7.4105706214904785     -1.1366562843322754     17.06450653076172       -7.277851581573486      7.186253547668457       -37.862789154052734     2.21020770072937        -14.829334259033203     5.599830627441406       27.80745506286621       -5.512645244598389      -1.1366562843322754     17.06450653076172       -20.73367691040039      -8.826581001281738      -10.555018424987793     -8.217794418334961
      -6.360044956207275      -1.9607794284820557     6.345422267913818       13.062686920166016
5641105 AP1G1   CCND2   2.3494300842285156      10.119391441345215      27.10674476623535       3.8083128929138184      -70.73456573486328      -39.372581481933594     -8.208958625793457
      -0.1798282265663147     1.082576036453247       12.70481014251709       -63.872154235839844     39.83904266357422       27.10674476623535       0.01608092524111271     -70.73456573486328      -8.661237716674805      43.937278747558594      -0.1798282265663147     1.082576036453247       -3.672504425048828      -3.3072872161865234     10.119391441345215      -8.377813339233398      3.8083128929138184      -26.24537467956543      -10.137262344360352     2.9316258430480957      15.313714027404785      7.0047502517700195      -12.949808120727539     -2.3481321334838867     12.740055084228516      -3.4322025775909424     8.920576095581055       -62.727718353271484     0.2877853512763977      -19.20431137084961      11.22409725189209       27.80745506286621       -1.9983365535736084     -2.3481321334838867     12.740055084228516      -33.702674865722656     -8.826581001281738      -18.610857009887695     -10.137262344360352
     -6.804142475128174      -0.43901631236076355    18.789241790771484      15.554900169372559
5641113 AP1G1   CCNH    4.718714237213135       1230632818573312.0      27.10674476623535       4.7800703048706055      -70.73456573486328      -47.087345123291016     -6.196646690368652
      -1.9009416103363037     474487485104128.0       25.461158752441406      -90.02267456054688      39.83904266357422       27.10674476623535       0.7240228652954102      -70.73456573486328      -14.690686225891113     53.84657669067383       -1.9009416103363037     474487485104128.0       -4.566867828369141      -555133515595776.0      1230632818573312.0      -328591573254144.0      4.7800703048706055      -1088045541490688.0     -10.137262344360352     2.9316258430480957      19.262754440307617      11.107816696166992      -12.949808120727539     -2.3481321334838867     17.06450653076172       -7.277851581573486      17.50507164001465       -45.33726501464844      0.9687032103538513      -33.4061164855957       8564995327524864.0      38.147640228271484      -3.5528361797332764     -2.3481321334838867     17.06450653076172       -33.702674865722656     -8.826581001281738      -27.176956176757812     -10.137262344360352
     -6.431360721588135      -0.43901631236076355    3244183414374400.0      15.554900169372559

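It turns out that pulling the values out of the df with a = df.values does not coerce strings to np.nan, and my original df evidently contained strings. That is why my attempts to cast the array created from df.values and then slice out the values failed: the items simply stayed dtype "object".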
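A minimal sketch of this effect, using a made-up two-row frame (the column names mirror the table above, the values are invented for illustration). Whether einsum then raises the TypeError may depend on the NumPy version, so the sketch only inspects the dtypes:

```python
import numpy as np
import pandas as pd

# Hypothetical miniature frame: two string columns followed by float32 columns.
df = pd.DataFrame({
    "rna": ["AP1G1", "AP1G1"],
    "cnv": ["CCL8", "CCND2"],
    "1_a": np.array([3.5, 2.5], dtype=np.float32),
    "1_b": np.array([1.5, 0.5], dtype=np.float32),
})

a = df.values                   # mixed string/float columns give an object array
print(a.dtype)

diffs = a[:, 2:3] - a[:, 3:]    # slicing off the numeric part keeps the object dtype
print(diffs.dtype)

# The individual numeric columns are still float32; only the combined array is not.
print(df["1_a"].dtype)
```

Printing a.dtype and diffs.dtype right before the einsum call is a cheap way to spot this situation.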

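To fix this, I simply selected the numeric columns from the original df and converted them to a matrix:

# as_matrix() was removed in pandas 1.0; to_numpy() is the current equivalent
a = df[df.columns[2:]].to_numpy()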

Then I made sure to update the indices in the diff operation, since the column indices have shifted back by two:

diffs = a[:,:25] - a[:,25:]
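Put together, the fix might look like the sketch below. The frame is synthetic, n_pairs is a hypothetical stand-in for the 25 paired columns, and dividing by n_pairs (rather than the 3.0 used in the question) is an assumption about the intended mean:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_pairs = 3  # hypothetical; the real data has 25 "_a" and 25 "_b" columns

# Synthetic float32 columns named like the real ones.
cols = [f"{i}_a" for i in range(1, n_pairs + 1)] + [f"{i}_b" for i in range(1, n_pairs + 1)]
df = pd.DataFrame(rng.normal(size=(4, 2 * n_pairs)).astype(np.float32), columns=cols)

# Prepend the two string columns so the layout matches the question.
df.insert(0, "cnv", ["CCL8", "CCND2", "CCNH", "CCL8"])
df.insert(0, "rna", ["AP1G1"] * 4)

# Select only the numeric columns before converting, so the dtype stays float32.
a = df[df.columns[2:]].to_numpy()
print(a.dtype)

diffs = a[:, :n_pairs] - a[:, n_pairs:]

# Row-wise sum of squares via einsum, then root mean square per row.
rmse_out = np.sqrt(np.einsum('ij,ij->i', diffs, diffs) / n_pairs)

# Cross-check against the plain NumPy formulation.
reference = np.sqrt((diffs ** 2).mean(axis=1))
print(np.allclose(rmse_out, reference, rtol=1e-4))
```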

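Takeaway: when einsum behaves badly, look for strings or "object" entries in your arrays, or anything else that is not float32 or float64. Checking the array's dtype before the call is the quickest test.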
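If the stray strings sit inside columns that are supposed to be numeric, one option is to coerce them explicitly before building the array. The helper below is a sketch, not part of the original fix: numeric_block is a hypothetical name, and it relies on pd.to_numeric with errors="coerce" to turn unparseable entries into NaN:

```python
import numpy as np
import pandas as pd

def numeric_block(df, columns, dtype=np.float32):
    """Return the given columns as a float array, with non-numeric entries as NaN.

    Hypothetical helper for illustration only.
    """
    block = df[columns].apply(pd.to_numeric, errors="coerce")
    return block.to_numpy(dtype=dtype)

# Made-up frame in which a stray string has turned a numeric column into object dtype.
df = pd.DataFrame({
    "rna": ["AP1G1", "AP1G1", "AP1G1"],
    "1_a": [1.5, "oops", 3.0],
    "1_b": [0.5, 2.0, 1.0],
})

arr = numeric_block(df, ["1_a", "1_b"])
print(arr.dtype)

# Rows containing a coerced NaN can then be dropped before calling einsum.
clean = arr[~np.isnan(arr).any(axis=1)]
print(clean.shape)
```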