我如何利用 JAX 库来加速我的代码?
How can I utilize JAX library to speed up my code?
我写了一段代码,获取一些顶点并根据一些规则重新排列它们。当输入包含大数据时,代码 运行 非常缓慢,例如对于 60000 个循环,google colab TPU 运行 时间大约需要 15 个小时。我发现 JAX 是这样做并尝试使用它的最佳库之一,但由于缺乏处理此类大数据及其相关方法(如并行化)的经验,我遇到了一些问题。创建以下小示例以显示代码的作用:
import numpy as np
# <class 'numpy.ma.core.MaskedArray'> <class 'numpy.ma.core.MaskedArray'> (m, 4) <class 'numpy.int64'>
nodes = np.ma.masked_array(np.array([[0, 1, 2, 3], [4, 0, 5, 1], [6, 4, 7, 5], [8, 6, 9, 7]],
dtype=np.int64), mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# <class 'numpy.ndarray'> <class 'numpy.ndarray'> (n, 3) <class 'numpy.float64'>
vert = np.array([[0.06944111, -0.12027553, -0.3], [0., -0.13888221, -0.3], [0.05, -0.08660254, -0.3],
[0.06944111, -0.12027553, -0.5], [0.06944111, -0.12027553, -0.1], [0., -0.13888221, -0.1],
[0.06944111, -0.12027553, 0.1], [0., -0.13888221, 0.1], [0.06944111, -0.12027553, 0.3],
[0., -0.13888221, 0.3]])
def ali_sh():
mod_array = []
mod_idx = []
for cell in range(len(nodes)):
vertex_idx = []
B_face = sorted(nodes[cell], key=lambda v: [vert[v][0]], reverse=True)
if round(vert[B_face[1]][0], 7) == round(vert[B_face[2]][0], 7):
if vert[B_face[1]][1] > vert[B_face[2]][1]:
B_face[1], B_face[2] = B_face[2], B_face[1]
mod_array.append(B_face)
for vertex in B_face:
vertex_idx.append(np.where(nodes[cell] == vertex)[0][0])
mod_idx.append(vertex_idx)
return mod_idx
mod_idx = ali_sh()
以上代码只是我的代码的一个视图,并且有一些差异,例如在此代码中 jnp.where
运行 正确但使用 the main code and the big data 它将卡住并且必须改用 np.where
。在我的第一次尝试中,我在代码末尾添加了 jax_r = jit(ali_sh)
和 mod_idx = jax_r().block_until_ready()
,但我没有获得更好的性能。我使用了 FiPy 库及其方法,其中在 numpy 类型中,例如'fipy.mesh.vertexCoords.T' 是一个 numpy ndarray。我试图通过 jnp.array(fipy numpy arrays)
将使用过的 fipy numpy 数组转换为 JAX 数组以检查它是否有帮助,但是由于通过 sorted
命令使用 lambda
我得到了错误。如何在我的代码上实现 JAX 以获得更好的 运行 时间。
colab 是否需要做任何事情才能在 TPU 或 GPU 上获得此类代码的最大能力?
使用 JAX 会对我的代码加速产生重大影响吗?如果有人可以帮助找出如何加速代码,我将不胜感激。
编写高效的 JAX 代码与编写高效的 NumPy 代码非常相似:通常,如果您对数据行使用 for
循环,您的代码将不会非常高效。相反,您应该努力根据向量化操作来编写计算。
在您的代码中,您似乎依赖于许多非 JAX 元素(例如 NumPy 掩码数组、FiPy 中的操作等),因此 JAX 不太可能改进您的运行时间。相反,我会专注于重写您的代码以有效利用 NumPy,将 for
循环逻辑替换为 NumPy 向量化操作。
下面是一个用向量化运算表达函数的例子:
def ali_sh_vectorized():
i_sort = np.argsort(vert[nodes, 0], axis=1)[:, ::-1]
B_face = nodes[np.arange(nodes.shape[0])[:, None], i_sort]
close = np.isclose(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
larger = np.greater(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
col_1 = np.where(close & larger, B_face[:, 2], B_face[:, 1])
col_2 = np.where(close & larger, B_face[:, 1], B_face[:, 2])
B_face[:, 1] = col_1
B_face[:, 2] = col_2
mod_idx = np.where(nodes[:, :, None] == B_face[:, None, :])[2].reshape(nodes.shape)
return mod_idx
输出与原始函数的差异是由于 Python 排序和 NumPy 排序处理等效元素的方式不同,但我相信整体逻辑是相同的。
我写了一段代码,获取一些顶点并根据一些规则重新排列它们。当输入包含大数据时,代码 运行 非常缓慢,例如对于 60000 个循环,google colab TPU 运行 时间大约需要 15 个小时。我发现 JAX 是这样做并尝试使用它的最佳库之一,但由于缺乏处理此类大数据及其相关方法(如并行化)的经验,我遇到了一些问题。创建以下小示例以显示代码的作用:
import numpy as np
# <class 'numpy.ma.core.MaskedArray'> <class 'numpy.ma.core.MaskedArray'> (m, 4) <class 'numpy.int64'>
nodes = np.ma.masked_array(np.array([[0, 1, 2, 3], [4, 0, 5, 1], [6, 4, 7, 5], [8, 6, 9, 7]],
dtype=np.int64), mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# <class 'numpy.ndarray'> <class 'numpy.ndarray'> (n, 3) <class 'numpy.float64'>
vert = np.array([[0.06944111, -0.12027553, -0.3], [0., -0.13888221, -0.3], [0.05, -0.08660254, -0.3],
[0.06944111, -0.12027553, -0.5], [0.06944111, -0.12027553, -0.1], [0., -0.13888221, -0.1],
[0.06944111, -0.12027553, 0.1], [0., -0.13888221, 0.1], [0.06944111, -0.12027553, 0.3],
[0., -0.13888221, 0.3]])
def ali_sh():
mod_array = []
mod_idx = []
for cell in range(len(nodes)):
vertex_idx = []
B_face = sorted(nodes[cell], key=lambda v: [vert[v][0]], reverse=True)
if round(vert[B_face[1]][0], 7) == round(vert[B_face[2]][0], 7):
if vert[B_face[1]][1] > vert[B_face[2]][1]:
B_face[1], B_face[2] = B_face[2], B_face[1]
mod_array.append(B_face)
for vertex in B_face:
vertex_idx.append(np.where(nodes[cell] == vertex)[0][0])
mod_idx.append(vertex_idx)
return mod_idx
mod_idx = ali_sh()
以上代码只是我的代码的一个视图,并且有一些差异,例如在此代码中 jnp.where
运行 正确但使用 the main code and the big data 它将卡住并且必须改用 np.where
。在我的第一次尝试中,我在代码末尾添加了 jax_r = jit(ali_sh)
和 mod_idx = jax_r().block_until_ready()
,但我没有获得更好的性能。我使用了 FiPy 库及其方法,其中在 numpy 类型中,例如'fipy.mesh.vertexCoords.T' 是一个 numpy ndarray。我试图通过 jnp.array(fipy numpy arrays)
将使用过的 fipy numpy 数组转换为 JAX 数组以检查它是否有帮助,但是由于通过 sorted
命令使用 lambda
我得到了错误。如何在我的代码上实现 JAX 以获得更好的 运行 时间。
colab 是否需要做任何事情才能在 TPU 或 GPU 上获得此类代码的最大能力? 使用 JAX 会对我的代码加速产生重大影响吗?如果有人可以帮助找出如何加速代码,我将不胜感激。
编写高效的 JAX 代码与编写高效的 NumPy 代码非常相似:通常,如果您对数据行使用 for
循环,您的代码将不会非常高效。相反,您应该努力根据向量化操作来编写计算。
在您的代码中,您似乎依赖于许多非 JAX 元素(例如 NumPy 掩码数组、FiPy 中的操作等),因此 JAX 不太可能改进您的运行时间。相反,我会专注于重写您的代码以有效利用 NumPy,将 for
循环逻辑替换为 NumPy 向量化操作。
下面是一个用向量化运算表达函数的例子:
def ali_sh_vectorized():
i_sort = np.argsort(vert[nodes, 0], axis=1)[:, ::-1]
B_face = nodes[np.arange(nodes.shape[0])[:, None], i_sort]
close = np.isclose(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
larger = np.greater(vert[B_face[:, 1],1], vert[B_face[:, 2], 2])
col_1 = np.where(close & larger, B_face[:, 2], B_face[:, 1])
col_2 = np.where(close & larger, B_face[:, 1], B_face[:, 2])
B_face[:, 1] = col_1
B_face[:, 2] = col_2
mod_idx = np.where(nodes[:, :, None] == B_face[:, None, :])[2].reshape(nodes.shape)
return mod_idx
输出与原始函数的差异是由于 Python 排序和 NumPy 排序处理等效元素的方式不同,但我相信整体逻辑是相同的。