Numpy einsum 表现不佳。需要注意什么?
Numpy einsum behaving badly. What to look out for?
当 numpy einsum 抛出错误时通常失败的是什么:
Traceback (most recent call last):
File "rmse_iter.py", line 30, in <module>
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
TypeError: invalid data type for einsum
numpy 数组 diff
由两个 pandas 数据帧相减产生,并且仅包含 np.float32
类型的数字——没有字符串、nan、+/-inf、或任何其他此类有趣的事情。那么我应该寻找什么?在什么情况下 einsum 通常会以这种方式失败?
这就是我加载和处理数据帧的方式:
df = pd.read_pickle(fn)
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.dropna(inplace=True)
a = df.values
diffs = a[:,2:27] - a[:,27:]
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
请原谅问题的开放性。感谢 向我介绍了 einsum 魔法。
编辑:
这是我尝试以表格形式包含实际数据的尝试:
rna cnv 1_a 2_a 3_a 4_a 5_a 6_a 7_a 8_a 9_a 10_a 11_a 12_a 13_a 14_a 15_a 16_a 17_a 18_a 19_a 20_a 21_a 22_a 23_a 24_a 25_a 1_b 2_b 3_b 4_b 5_b 6_b 7_b 8_b 9_b 10_b 11_b 12_b 13_b 14_b 15_b 16_b 17_b 18_b 19_b 20_b 21_b 22_b 23_b 24_b 25_b
5641095 AP1G1 CCL8 3.588543653488159 10.119391441345215 32.92853546142578 6.307891368865967 -32.6164665222168 -34.94172286987305 -4.913632869720459
-0.1798282265663147 -0.5144565105438232 12.70481014251709 -37.560791015625 39.83904266357422 32.92853546142578 -0.9303828477859497 -32.6164665222168 -8.661237716674805 31.074113845825195 -0.1798282265663147 -0.5144565105438232 -4.566867828369141 -2.5914463996887207 10.119391441345215 -12.007019996643066 6.307891368865967 -21.65423583984375 -8.217794418334961 2.9316258430480957 27.942243576049805 11.107816696166992 -7.4105706214904785 -1.1366562843322754 17.06450653076172 -7.277851581573486 7.186253547668457 -37.862789154052734 2.21020770072937 -14.829334259033203 5.599830627441406 27.80745506286621 -5.512645244598389 -1.1366562843322754 17.06450653076172 -20.73367691040039 -8.826581001281738 -10.555018424987793 -8.217794418334961
-6.360044956207275 -1.9607794284820557 6.345422267913818 13.062686920166016
5641105 AP1G1 CCND2 2.3494300842285156 10.119391441345215 27.10674476623535 3.8083128929138184 -70.73456573486328 -39.372581481933594 -8.208958625793457
-0.1798282265663147 1.082576036453247 12.70481014251709 -63.872154235839844 39.83904266357422 27.10674476623535 0.01608092524111271 -70.73456573486328 -8.661237716674805 43.937278747558594 -0.1798282265663147 1.082576036453247 -3.672504425048828 -3.3072872161865234 10.119391441345215 -8.377813339233398 3.8083128929138184 -26.24537467956543 -10.137262344360352 2.9316258430480957 15.313714027404785 7.0047502517700195 -12.949808120727539 -2.3481321334838867 12.740055084228516 -3.4322025775909424 8.920576095581055 -62.727718353271484 0.2877853512763977 -19.20431137084961 11.22409725189209 27.80745506286621 -1.9983365535736084 -2.3481321334838867 12.740055084228516 -33.702674865722656 -8.826581001281738 -18.610857009887695 -10.137262344360352
-6.804142475128174 -0.43901631236076355 18.789241790771484 15.554900169372559
5641113 AP1G1 CCNH 4.718714237213135 1230632818573312.0 27.10674476623535 4.7800703048706055 -70.73456573486328 -47.087345123291016 -6.196646690368652
-1.9009416103363037 474487485104128.0 25.461158752441406 -90.02267456054688 39.83904266357422 27.10674476623535 0.7240228652954102 -70.73456573486328 -14.690686225891113 53.84657669067383 -1.9009416103363037 474487485104128.0 -4.566867828369141 -555133515595776.0 1230632818573312.0 -328591573254144.0 4.7800703048706055 -1088045541490688.0 -10.137262344360352 2.9316258430480957 19.262754440307617 11.107816696166992 -12.949808120727539 -2.3481321334838867 17.06450653076172 -7.277851581573486 17.50507164001465 -45.33726501464844 0.9687032103538513 -33.4061164855957 8564995327524864.0 38.147640228271484 -3.5528361797332764 -2.3481321334838867 17.06450653076172 -33.702674865722656 -8.826581001281738 -27.176956176757812 -10.137262344360352
-6.431360721588135 -0.43901631236076355 3244183414374400.0 15.554900169372559
事实证明,使用 = df.values 从 df 中提取值不允许将字符串强制转换为 np.nan,这显然在我的原始 df 中。这就是为什么我尝试从 df.values 创建的数组中进行类型转换然后切片所有这些值失败的原因——这些项目只是保持 "object".
为了解决这个问题,我只是从原始 df 中选择了数字列并将它们发送到矩阵:
a= df[df.columns[2:]].as_matrix()
然后我确保在 diff 操作中更新索引,因为列索引向后移动了两个:
diffs = a[:,:25] - a[:,25:]
要点:当 einsum 表现不佳时,在您的数组中查找字符串或 "objects",否则它们不是 float32 或 float64。
当 numpy einsum 抛出错误时通常失败的是什么:
Traceback (most recent call last):
File "rmse_iter.py", line 30, in <module>
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
TypeError: invalid data type for einsum
numpy 数组 diff
由两个 pandas 数据帧相减产生,并且仅包含 np.float32
类型的数字——没有字符串、nan、+/-inf、或任何其他此类有趣的事情。那么我应该寻找什么?在什么情况下 einsum 通常会以这种方式失败?
这就是我加载和处理数据帧的方式:
df = pd.read_pickle(fn)
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.dropna(inplace=True)
a = df.values
diffs = a[:,2:27] - a[:,27:]
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
请原谅问题的开放性。感谢
编辑:
这是我尝试以表格形式包含实际数据的尝试:
rna cnv 1_a 2_a 3_a 4_a 5_a 6_a 7_a 8_a 9_a 10_a 11_a 12_a 13_a 14_a 15_a 16_a 17_a 18_a 19_a 20_a 21_a 22_a 23_a 24_a 25_a 1_b 2_b 3_b 4_b 5_b 6_b 7_b 8_b 9_b 10_b 11_b 12_b 13_b 14_b 15_b 16_b 17_b 18_b 19_b 20_b 21_b 22_b 23_b 24_b 25_b
5641095 AP1G1 CCL8 3.588543653488159 10.119391441345215 32.92853546142578 6.307891368865967 -32.6164665222168 -34.94172286987305 -4.913632869720459
-0.1798282265663147 -0.5144565105438232 12.70481014251709 -37.560791015625 39.83904266357422 32.92853546142578 -0.9303828477859497 -32.6164665222168 -8.661237716674805 31.074113845825195 -0.1798282265663147 -0.5144565105438232 -4.566867828369141 -2.5914463996887207 10.119391441345215 -12.007019996643066 6.307891368865967 -21.65423583984375 -8.217794418334961 2.9316258430480957 27.942243576049805 11.107816696166992 -7.4105706214904785 -1.1366562843322754 17.06450653076172 -7.277851581573486 7.186253547668457 -37.862789154052734 2.21020770072937 -14.829334259033203 5.599830627441406 27.80745506286621 -5.512645244598389 -1.1366562843322754 17.06450653076172 -20.73367691040039 -8.826581001281738 -10.555018424987793 -8.217794418334961
-6.360044956207275 -1.9607794284820557 6.345422267913818 13.062686920166016
5641105 AP1G1 CCND2 2.3494300842285156 10.119391441345215 27.10674476623535 3.8083128929138184 -70.73456573486328 -39.372581481933594 -8.208958625793457
-0.1798282265663147 1.082576036453247 12.70481014251709 -63.872154235839844 39.83904266357422 27.10674476623535 0.01608092524111271 -70.73456573486328 -8.661237716674805 43.937278747558594 -0.1798282265663147 1.082576036453247 -3.672504425048828 -3.3072872161865234 10.119391441345215 -8.377813339233398 3.8083128929138184 -26.24537467956543 -10.137262344360352 2.9316258430480957 15.313714027404785 7.0047502517700195 -12.949808120727539 -2.3481321334838867 12.740055084228516 -3.4322025775909424 8.920576095581055 -62.727718353271484 0.2877853512763977 -19.20431137084961 11.22409725189209 27.80745506286621 -1.9983365535736084 -2.3481321334838867 12.740055084228516 -33.702674865722656 -8.826581001281738 -18.610857009887695 -10.137262344360352
-6.804142475128174 -0.43901631236076355 18.789241790771484 15.554900169372559
5641113 AP1G1 CCNH 4.718714237213135 1230632818573312.0 27.10674476623535 4.7800703048706055 -70.73456573486328 -47.087345123291016 -6.196646690368652
-1.9009416103363037 474487485104128.0 25.461158752441406 -90.02267456054688 39.83904266357422 27.10674476623535 0.7240228652954102 -70.73456573486328 -14.690686225891113 53.84657669067383 -1.9009416103363037 474487485104128.0 -4.566867828369141 -555133515595776.0 1230632818573312.0 -328591573254144.0 4.7800703048706055 -1088045541490688.0 -10.137262344360352 2.9316258430480957 19.262754440307617 11.107816696166992 -12.949808120727539 -2.3481321334838867 17.06450653076172 -7.277851581573486 17.50507164001465 -45.33726501464844 0.9687032103538513 -33.4061164855957 8564995327524864.0 38.147640228271484 -3.5528361797332764 -2.3481321334838867 17.06450653076172 -33.702674865722656 -8.826581001281738 -27.176956176757812 -10.137262344360352
-6.431360721588135 -0.43901631236076355 3244183414374400.0 15.554900169372559
事实证明,使用 = df.values 从 df 中提取值不允许将字符串强制转换为 np.nan,这显然在我的原始 df 中。这就是为什么我尝试从 df.values 创建的数组中进行类型转换然后切片所有这些值失败的原因——这些项目只是保持 "object".
为了解决这个问题,我只是从原始 df 中选择了数字列并将它们发送到矩阵:
a= df[df.columns[2:]].as_matrix()
然后我确保在 diff 操作中更新索引,因为列索引向后移动了两个:
diffs = a[:,:25] - a[:,25:]
要点:当 einsum 表现不佳时,在您的数组中查找字符串或 "objects",否则它们不是 float32 或 float64。