如果包含任何 nan,则删除整个 3d 数组

Delete whole 3d array if contains any nan

我有一组由 3d ndarray 表示的图像。抽象地说,我想做的是,删除整个图像,如果它的任何像素值是 nan。 假设我们得到了以下 ndarray:

a = np.arange(18).reshape(3, 2, 3)
a = 1.0 * a
a[0][0][1] = np.nan
a[1][0][0] = np.nan
a
[[[ 0. nan  2.]
  [ 3.  4.  5.]]

 [[nan  7.  8.]
  [ 9. 10. 11.]]

 [[12. 13. 14.]
  [15. 16. 17.]]]

现在我想要得到的是一个给定 ndarray returns True, True, False 的函数。为了最终使用np.delete.

我尝试了以下方法,效果很好:

np.delete(a, [np.isnan(image.flatten()).any() for image in a], axis=0)
array([[[12., 13., 14.],
        [15., 16., 17.]]]))

但是,我很难相信numpy中没有比它更高效的函数,而且由于我有很多图像,所以我想尽可能地优化它。

正如 Michael Szczesny 已经回答的那样,更 Pythonic 的方式是:

filtered_images=a[~np.isnan(a).any(axis=(2,1))]

如果那段代码难以理解,那么考虑用for循环提取每张图片如下:

filtered_images=list()
for value in a:
  if(np.isnan(value).any()!=True):
    filtered_images.append(value)

两种方法的输出应该相似!