JAX 中的累积

Accumulation in JAX

在 JAX 中编译累积时处理内存的最佳方法是什么,例如 jax.lax.scan,满缓冲区过多?

以下是等比级数的例子。 诱惑是认识到累加仅取决于输入大小并相应地实施

import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size,x0,a):
    scan_fun = lambda carry, i : (a*carry,)*2
    xn, x = lax.scan(scan_fun,x0,None,length=size-1)
    return jnp.concatenate((x0[None],x))

jax.config.update("jax_enable_x64", True)

size = jnp.array(2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')

jax.jit(calc_gp_size)(size,x0,a)

但是,尝试使用 jax.jit 将可预见地导致 ConcretizationTypeError

正确的方法是在缓冲区已经存在的地方传递一个参数。

def calc_gp_array(array,x0,a):
    scan_fun = lambda carry, i : (a*carry,)*2
    xn, x = lax.scan(scan_fun,x0,array)
    return jnp.concatenate((x0[None],x))

array = jnp.arange(1,2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')  

jax.jit(calc_gp_array)(array,x0,a)

我担心的是有很多分配的内存没有被使用(或者是吗?)。此示例是否有更高效的内存方法,或者是否以某种方式使用了分配的内存?

编辑:合并@jakevdp的注释,将函数视为主要函数(单次调用——包括编译和排除缓存),分析结果

%memit jx.jit(calc_gp_size, static_argnums=0)(size,x0,a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB

%memit jx.jit(calc_gp_array)(jnp.arange(1,size,dtype='u8'),x0,a).block_until_ready()
peak memory: 7850.83 MiB, increment: 1240.22 MiB

%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
peak memory: 8150.05 MiB, increment: 1539.70 MiB

不太精细的结果需要对 jit 代码进行行分析(不确定如何完成)。

顺序初始化数组然后调用jax.jit似乎可以节省内存

%memit array = jnp.arange(1,size,dtype='u8'); jx.jit(calc_gp_array)(array,x0,a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB

%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB

如果您将大小参数标记为静态并传递可散列值,第一个版本将起作用:

import jax
import jax.numpy as jnp
import jax.lax as lax

def calc_gp_size(size,x0,a):
    scan_fun = lambda carry, i : (a*carry,)*2
    xn, x = lax.scan(scan_fun,x0,None,length=size-1)
    return jnp.concatenate((x0[None],x))

jax.config.update("jax_enable_x64", True)

size = 2 ** 26
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')

jax.jit(calc_gp_size, static_argnums=0)(size,x0,a)
# DeviceArray([1.        , 1.00000001, 1.00000002, ..., 1.95636587,
#              1.95636589, 1.95636591], dtype=float64)

我认为这可能比 pre-allocating 第二个示例中的数组的内存效率略高,但如果这很重要,则值得进行基准测试。

此外,如果您在 GPU 上执行此类操作,您可能会发现 built-in 像 jnp.cumprod 这样的累加性能更高。我相信这或多或少等同于您的 scan-based 函数:

result = jnp.cumprod(jnp.full(size, 1 + 1E-8, dtype='f8').at[0].set(1))