我如何利用 JAX 库来加速我的代码?

How can I utilize JAX library to speed up my code?

我写了一段代码,获取一些顶点并根据一些规则重新排列它们。当输入包含大数据时,代码 运行 非常缓慢,例如对于 60000 个循环,google colab TPU 运行 时间大约需要 15 个小时。我发现 JAX 是这样做并尝试使用它的最佳库之一,但由于缺乏处理此类大数据及其相关方法(如并行化)的经验,我遇到了一些问题。创建以下小示例以显示代码的作用:

import numpy as np

# <class 'numpy.ma.core.MaskedArray'> <class 'numpy.ma.core.MaskedArray'> (m, 4) <class 'numpy.int64'>
nodes = np.ma.masked_array(np.array([[0, 1, 2, 3], [4, 0, 5, 1], [6, 4, 7, 5], [8, 6, 9, 7]],
                                    dtype=np.int64), mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

# <class 'numpy.ndarray'> <class 'numpy.ndarray'> (n, 3) <class 'numpy.float64'>
vert = np.array([[0.06944111, -0.12027553, -0.3], [0., -0.13888221, -0.3], [0.05, -0.08660254, -0.3],
                [0.06944111, -0.12027553, -0.5], [0.06944111, -0.12027553, -0.1], [0., -0.13888221, -0.1],
                [0.06944111, -0.12027553,  0.1], [0., -0.13888221,  0.1], [0.06944111, -0.12027553,  0.3],
                [0., -0.13888221,  0.3]])


def ali_sh():
    mod_array = []
    mod_idx = []

    for cell in range(len(nodes)):
        vertex_idx = []

        B_face = sorted(nodes[cell], key=lambda v: [vert[v][0]], reverse=True)
        if round(vert[B_face[1]][0], 7) == round(vert[B_face[2]][0], 7):
            if vert[B_face[1]][1] > vert[B_face[2]][1]:
                B_face[1], B_face[2] = B_face[2], B_face[1]

        mod_array.append(B_face)

        for vertex in B_face:
            vertex_idx.append(np.where(nodes[cell] == vertex)[0][0])

        mod_idx.append(vertex_idx)

    return mod_idx

mod_idx = ali_sh()

以上代码只是我的代码的一个视图,并且有一些差异,例如在此代码中 jnp.where 运行 正确但使用 the main code and the big data 它将卡住并且必须改用 np.where 。在我的第一次尝试中,我在代码末尾添加了 jax_r = jit(ali_sh)mod_idx = jax_r().block_until_ready(),但我没有获得更好的性能。我使用了 FiPy 库及其方法,其中在 numpy 类型中,例如'fipy.mesh.vertexCoords.T' 是一个 numpy ndarray。我试图通过 jnp.array(fipy numpy arrays) 将使用过的 fipy numpy 数组转换为 JAX 数组以检查它是否有帮助,但是由于通过 sorted 命令使用 lambda 我得到了错误。如何在我的代码上实现 JAX 以获得更好的 运行 时间。

colab 是否需要做任何事情才能在 TPU 或 GPU 上获得此类代码的最大能力? 使用 JAX 会对我的代码加速产生重大影响吗?如果有人可以帮助找出如何加速代码,我将不胜感激。

编写高效的 JAX 代码与编写高效的 NumPy 代码非常相似:通常,如果您对数据行使用 for 循环,您的代码将不会非常高效。相反,您应该努力根据向量化操作来编写计算。

在您的代码中,您似乎依赖于许多非 JAX 元素(例如 NumPy 掩码数组、FiPy 中的操作等),因此 JAX 不太可能改进您的运行时间。相反,我会专注于重写您的代码以有效利用 NumPy,将 for 循环逻辑替换为 NumPy 向量化操作。

下面是一个用向量化运算表达函数的例子:

def ali_sh_vectorized():
  i_sort = np.argsort(vert[nodes, 0], axis=1)[:, ::-1]
  B_face = nodes[np.arange(nodes.shape[0])[:, None], i_sort]
  close = np.isclose(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
  larger = np.greater(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
  col_1 = np.where(close & larger, B_face[:, 2], B_face[:, 1])
  col_2 = np.where(close & larger, B_face[:, 1], B_face[:, 2])
  B_face[:, 1] = col_1
  B_face[:, 2] = col_2
  mod_idx = np.where(nodes[:, :, None] == B_face[:, None, :])[2].reshape(nodes.shape)
  return mod_idx

输出与原始函数的差异是由于 Python 排序和 NumPy 排序处理等效元素的方式不同,但我相信整体逻辑是相同的。