Jax 中的 vmap ops.index_update
vmap ops.index_update in Jax
我有下面的代码,它使用了一个简单的 for 循环。我只是想知道是否有办法 vmap 它?原代码如下:
import numpy as np
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax
data = np.random.rand(192,334)
a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99)
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)
@jax.jit
def filter_jax(y):
for ind in range(0, len(y)):
y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
return y
jnpData = jnp.asarray(data)
%timeit filter_jax(jnpData).block_until_ready()
这是我使用 vmap 的尝试:
def paraUpdate(y, ind):
return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
@jax.jit
def filter_jax2(y):
ranger = range(0, len(y))
return jax.vmap(paraUpdate, y)(ranger)
但我收到以下错误:
TypeError: vmap in_axes must be an int, None, or (nested) container
with those types as leaves, but got
Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.
我有点困惑,因为范围是 int 类型,所以我不太确定发生了什么。
最后,我正在尝试将这一小块优化得尽可能好,以获得最少的时间。
jax.vmap
可以表示单个操作独立应用于输入的多个轴的功能。您的功能有点不同:您将一个操作迭代应用于单个输入。
幸运的是 JAX 提供了 lax.scan
可以处理这种情况。实现看起来像这样:
from jax import lax
def paraUpdate(y, ind):
return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind
@jax.jit
def filter_jax2(y):
ranger = jnp.arange(len(y))
return lax.scan(paraUpdate, y, ranger)[0]
print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True
%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop
%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
如果您更改算法以便将操作应用于数组中的 每个 列而不是前 N 列,可以用vmap
这样表示:
@jax.jit
def filter_jax3(y):
f = lambda col: jscp.convolve(impulse_20, col)[:-19]
return jax.vmap(f, in_axes=1, out_axes=1)(y)
我有下面的代码,它使用了一个简单的 for 循环。我只是想知道是否有办法 vmap 它?原代码如下:
import numpy as np
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax
data = np.random.rand(192,334)
a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99)
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)
@jax.jit
def filter_jax(y):
for ind in range(0, len(y)):
y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
return y
jnpData = jnp.asarray(data)
%timeit filter_jax(jnpData).block_until_ready()
这是我使用 vmap 的尝试:
def paraUpdate(y, ind):
return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
@jax.jit
def filter_jax2(y):
ranger = range(0, len(y))
return jax.vmap(paraUpdate, y)(ranger)
但我收到以下错误:
TypeError: vmap in_axes must be an int, None, or (nested) container with those types as leaves, but got Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.
我有点困惑,因为范围是 int 类型,所以我不太确定发生了什么。
最后,我正在尝试将这一小块优化得尽可能好,以获得最少的时间。
jax.vmap
可以表示单个操作独立应用于输入的多个轴的功能。您的功能有点不同:您将一个操作迭代应用于单个输入。
幸运的是 JAX 提供了 lax.scan
可以处理这种情况。实现看起来像这样:
from jax import lax
def paraUpdate(y, ind):
return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind
@jax.jit
def filter_jax2(y):
ranger = jnp.arange(len(y))
return lax.scan(paraUpdate, y, ranger)[0]
print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True
%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop
%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
如果您更改算法以便将操作应用于数组中的 每个 列而不是前 N 列,可以用vmap
这样表示:
@jax.jit
def filter_jax3(y):
f = lambda col: jscp.convolve(impulse_20, col)[:-19]
return jax.vmap(f, in_axes=1, out_axes=1)(y)