在 jax 中嵌套 for 循环

Nested for loops in jax

我有一个处理图像的函数


# img : RGB image (512, 512, 3)
# kernel : 5x5 filtering kernel 
# dilation : can take integer values 1, 2, 4, ...
# some_data : no conditions are dependent on the value of this argument and it 
#            remains constant across multiple invocations this function
def filtering(img, kernel, dilation, some_data):
    h, w, _ = img.shape
    filtered_img = jnp.zeros(img.shape)

    radius = 2
    for i in range(radius, h-radius):
        for j in range(radius, w-radius):

            center_pos = np.array([i, j])
            sum = jnp.array([0.0, 0.0, 0.0])
            sum_w = 0.0
            
            for ii in range(-radius, radius + 1):
                for jj in range(-radius, radius + 1):

                    pos = center_pos + dilation * np.array([ii, jj])
                    
                    # if not for the `compute_weight` function this could have been a dilated convolution
                    weight = kernel[ii + radius, jj + radius] * compute_weight(center_pos, pos, some_data)
                    sum += img[pos[0], pos[1], :] * weight
                    sum_w += weight

            filtered_img = filtered_img.at[i, j].set(sum/sum_w)

    return filtered_img

第一个函数调用(jit 编译)大约需要。 6 小时到 运行(在 GPU 和 CPU 上都试过)。由于它是 jit 编译的,因此后续的 运行 可能会更快,但第一个 运行 非常昂贵。

我尝试删除 compute_weight 函数并将最里面的两个嵌套循环替换为 jnp.sum(img[i-radius:i+radius+1, j-radius:j+radius+1] * filter, axis=(0, 1)),但第一个函数调用仍然需要大约 30 分钟才能 运行。根据这一观察和 上的一些其他问题,这似乎是由于一般的 for loops

是否会以更实用的方式重写它并使用 loops 之类的 jax 构造,或者这是由于其他问题而发生的吗?

这里的问题不是执行时间,而是编译时间。 JAX 的 JIT 编译将展平所有 Python 控制流:这意味着对于您的输入,您正在为内部循环生成 512 * 512 * 5 * 5 个 jaxpr 副本,并将它们发送到 XLA 进行编译。由于编译成本大致与程序长度的平方成比例,因此结果将是一个非常长的编译。

你最好的选择可能是根据 jax.fori_loop 重写它,这将把循环逻辑直接降低到 XLA 而无需大量编译成本。

更好的是,因为看起来你正在做的是某种形式的卷积,所以可以用 jax.scipy.signal.convolve2d 之类的方式来表达它,这比手动循环要快得多.