如何加速图像的不同大小和形状的最大池化集群?

How can I speed up max pooling clusters of different sizes and shapes of an image?

我已将图像的像素聚类成不同大小和形状的聚类。我想尽可能快地对每个集群进行最大池化,因为最大池化发生在我的 CNN 的一层中。

澄清一下: 输入是具有以下形状 [batch_size、图像高度、图像宽度、通道数] 的一批图像。在开始训练我的 CNN 之前,我已经对每张图片进行了聚类。因此,对于每张图像,我都有一个形状为 [图像高度,图像宽度] 的标签数组。

如何最大池化所有标签具有相同标签的图像的所有像素?我知道如何使用 for 循环来做到这一点,但这非常缓慢。我正在寻找一种快速解决方案,理想情况下可以在不到一秒的时间内对每个图像的每个集群进行最大池化。

为了实现,我使用 Python3.7 和 PyTorch。

我想通了。 torch_scatter。 scatter_max(img, cluster_labels) 从每个簇中输出最大元素并从我的代码中删除 for 循环。