JAX 中的累积
Accumulation in JAX
在 JAX 中编译累积时处理内存的最佳方法是什么,例如 jax.lax.scan
,满缓冲区过多?
以下是等比级数的例子。
诱惑是认识到累加仅取决于输入大小并相应地实施
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = jnp.array(2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size)(size,x0,a)
但是,尝试使用 jax.jit
将可预见地导致 ConcretizationTypeError
。
正确的方法是在缓冲区已经存在的地方传递一个参数。
def calc_gp_array(array,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,array)
return jnp.concatenate((x0[None],x))
array = jnp.arange(1,2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_array)(array,x0,a)
我担心的是有很多分配的内存没有被使用(或者是吗?)。此示例是否有更高效的内存方法,或者是否以某种方式使用了分配的内存?
编辑:合并@jakevdp的注释,将函数视为主要函数(单次调用——包括编译和排除缓存),分析结果
%memit jx.jit(calc_gp_size, static_argnums=0)(size,x0,a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB
%memit jx.jit(calc_gp_array)(jnp.arange(1,size,dtype='u8'),x0,a).block_until_ready()
peak memory: 7850.83 MiB, increment: 1240.22 MiB
%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
peak memory: 8150.05 MiB, increment: 1539.70 MiB
不太精细的结果需要对 jit 代码进行行分析(不确定如何完成)。
顺序初始化数组然后调用jax.jit
似乎可以节省内存
%memit array = jnp.arange(1,size,dtype='u8'); jx.jit(calc_gp_array)(array,x0,a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB
%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB
如果您将大小参数标记为静态并传递可散列值,第一个版本将起作用:
import jax
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = 2 ** 26
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size, static_argnums=0)(size,x0,a)
# DeviceArray([1. , 1.00000001, 1.00000002, ..., 1.95636587,
# 1.95636589, 1.95636591], dtype=float64)
我认为这可能比 pre-allocating 第二个示例中的数组的内存效率略高,但如果这很重要,则值得进行基准测试。
此外,如果您在 GPU 上执行此类操作,您可能会发现 built-in 像 jnp.cumprod
这样的累加性能更高。我相信这或多或少等同于您的 scan-based 函数:
result = jnp.cumprod(jnp.full(size, 1 + 1E-8, dtype='f8').at[0].set(1))
在 JAX 中编译累积时处理内存的最佳方法是什么,例如 jax.lax.scan
,满缓冲区过多?
以下是等比级数的例子。 诱惑是认识到累加仅取决于输入大小并相应地实施
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = jnp.array(2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size)(size,x0,a)
但是,尝试使用 jax.jit
将可预见地导致 ConcretizationTypeError
。
正确的方法是在缓冲区已经存在的地方传递一个参数。
def calc_gp_array(array,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,array)
return jnp.concatenate((x0[None],x))
array = jnp.arange(1,2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_array)(array,x0,a)
我担心的是有很多分配的内存没有被使用(或者是吗?)。此示例是否有更高效的内存方法,或者是否以某种方式使用了分配的内存?
编辑:合并@jakevdp的注释,将函数视为主要函数(单次调用——包括编译和排除缓存),分析结果
%memit jx.jit(calc_gp_size, static_argnums=0)(size,x0,a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB
%memit jx.jit(calc_gp_array)(jnp.arange(1,size,dtype='u8'),x0,a).block_until_ready()
peak memory: 7850.83 MiB, increment: 1240.22 MiB
%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
peak memory: 8150.05 MiB, increment: 1539.70 MiB
不太精细的结果需要对 jit 代码进行行分析(不确定如何完成)。
顺序初始化数组然后调用jax.jit
似乎可以节省内存
%memit array = jnp.arange(1,size,dtype='u8'); jx.jit(calc_gp_array)(array,x0,a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB
%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB
如果您将大小参数标记为静态并传递可散列值,第一个版本将起作用:
import jax
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = 2 ** 26
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size, static_argnums=0)(size,x0,a)
# DeviceArray([1. , 1.00000001, 1.00000002, ..., 1.95636587,
# 1.95636589, 1.95636591], dtype=float64)
我认为这可能比 pre-allocating 第二个示例中的数组的内存效率略高,但如果这很重要,则值得进行基准测试。
此外,如果您在 GPU 上执行此类操作,您可能会发现 built-in 像 jnp.cumprod
这样的累加性能更高。我相信这或多或少等同于您的 scan-based 函数:
result = jnp.cumprod(jnp.full(size, 1 + 1E-8, dtype='f8').at[0].set(1))