清理 NaN 的 np 数组,同时相应地删除其他数组中的条目

Clean np array of NaN while deleting entries in other array accordingly

我有两个 numpy 数组,其中一个包含大约 1% 的 NaN。

a = np.array([-2,5,nan,6])
b = np.array([2,3,1,0])

我想使用 sklearnmean_squared_error.

计算 ab 的均方误差

所以我的问题是,从 a 中删除所有 NaN,同时尽可能高效地从 b 中删除所有对应条目的 pythonic 方法是什么?

您可以简单地使用原始 NumPy 的 np.nanmean 来达到这个目的:

In [136]: np.nanmean((a-b)**2)
Out[136]: 18.666666666666668

如果这不存在,或者您真的想使用 sklearn 方法,您可以创建一个 mask 来索引 NaN:

In [148]: mask = ~np.isnan(a)

In [149]: mean_squared_error(a[mask], b[mask])
Out[149]: 18.666666666666668