如何将这个嵌套的 for 循环写成列表理解?

How do I write this nested for loop as a list comprehension?

我正在处理 4D 数据集,其中有一个嵌套的 for 循环(4 个循环)。 for 循环有效,但 运行 需要一段时间:~5 分钟。我试图用列表理解来正确地写这个,但是我对如何在嵌套循环中做到这一点感到困惑:

data = np.random.rand(12, 27, 282, 375)

stdev_data = np.std(data, axis=1)

## nested for loop 

count = []

for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        for lat in range(data.shape[2]):
            for lon in range(data.shape[3]):
                count.append((data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]).sum(axis=0))

reshape_counts = np.reshape(count, data.shape)

这是我对列表理解的尝试:

i, j, lat, lon = data.shape[0], data.shape[1], data.shape[2], data.shape[3]
print(i, j, lat, lon)

test_list = [[(data < -1.282 * stdev_data).sum(axis=0) for lon in lat] for j in i]

我收到一条错误消息,提示 'int' 对象不可迭代。如何以列表理解的形式重写我的嵌套 for 循环以加快进程?

除了将 False 转换为 0 并将 True 转换为 1 之外,我认为总和不会做任何事情,因为您只是将两个数字相互比较。我认为这会做同样的事情(我找不到摆脱最后一个循环的方法,如果你真的需要它更快,也许 joblibnumba 会有所帮助,但我没有经常使用它们,所以不确定):

count = np.empty(data.shape)
for j in range(data.shape[1]):
    count[:,j,...] = (data[:,j,...] < -1.282*stdev_data).astype(np.int32)

但是标准偏差也不能为负,所以没有什么能满足上述条件,因为你乘以一个负数,但你所有的数据都在 0 和 1 之间,所以我建议仔细检查所有内容

鉴于您使用的是 numpy,我建议您利用它们的 for 循环是用 C 编写的,并且经常进行优化这一事实。您最终仍将单步执行数据,但速度要快得多。这种方法称为矢量化。

在这种情况下,您试图制作一个布尔掩码,这可以简化操作。请记住,表达式中的 .sum() 调用是一个转移注意力的问题:您实际上是在对标量布尔值求和,它总是给您零或一。

以下是如何在第二维中找到小于 -1.282 的 sigma 的点:

result = data < -1.282 * stdev_data[:, None, ...]

或者,您可以这样做

result = data < -1.282 * stdev_data.reshape(stdev_data.shape[0], 1, *stdev_data.shape[1:])

result = data < -1.282 * np.reshape(stdev_data, stdev_data.shape[:1] + (1,) + stdev_data.shape[1:])

一个更简单的解决方案是从一开始就将 keepdims=True 传递给 np.std

result = data < -1.282 * np.std(data, axis=1, keepdims=True)

keepdims=True 确保 std 的输出具有形状 (12, 1, 282, 375) 而不是 (12, 282, 375),因此您不需要自己重新插入维度.

现在,如果您真的想像您的问题所暗示的那样计算计数,您可以沿着第二个维度对 result 掩码求和:

counts = result.sum(axis=1)

最后,完全按照说明回答您的实际问题:for 循环直接转化为列表理解。在你的例子中,这意味着理解中有四个 fors,完全按照你最初拥有它们的顺序:

[data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]
    for i in range(data.shape[0])
        for j in range(data.shape[1])
            for lat in range(data.shape[2])
                for lon in range(data.shape[3])]

由于推导式被括号括起来,您可以像我一样将它们的内容自由地写在不同的行上,尽管这当然不是必需的。请注意,唯一真正的区别是 append 的内容排在第一位并且没有冒号。另外,那个红鲱鱼 sum 不见了。