在 JAX 中计算词向量移动平均值的最佳方法
Best way to compute the moving average of word vectors in JAX
假设我有一个形状为 (n_words, model_dim)
的矩阵 W
,其中 n_words
是句子中的单词数,model_dim
是 space表示词向量的地方。计算这些向量的移动平均值的最快方法是什么?
例如,window 大小为 2(window 长度 = 5),我可以有这样的东西(这会引发错误 TypeError: JAX 'Tracer' objects do not support item assignment
):
from jax import random
import jax.numpy as jnp
# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32))
ws = 2 # window size
N = W.shape[0] # number of words
new_W = jnp.zeros(W.shape)
for i in range(N):
window = W[max(0, i-ws):min(N, i+ws+1)]
n = window.shape[0]
for j in range(n):
new_W[i] += W[j] / n
我想 jnp.convolve
有一个更快的解决方案,但我不熟悉它。
您似乎在尝试进行卷积,因此 jnp.convolve
或类似方法可能是一种更高效的方法。
就是说,您的示例有点奇怪,因为 n
永远不会大于 4,因此您只能访问 W
的前四个元素。此外,您在内循环的每次迭代中都会覆盖先前的值,因此 new_W
的每一行仅包含 W
.
的前四行之一的缩放副本
将您的代码更改为我认为您的意思并使用 index_update 使其与 JAX 的不可变数组兼容给出了这个:
from jax import random
import jax.numpy as jnp
# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32))
ws = 2 # window size
N = W.shape[0] # number of words
new_W = jnp.zeros(W.shape)
for i in range(N):
window = W[max(0, i-ws):min(N, i+ws)]
n = window.shape[0]
for j in range(n):
new_W = new_W.at[i].add(window[j] / n)
这里是一个更高效的卷积的等价物:
from jax.scipy.signal import convolve
kernel = jnp.ones((4, 1))
new_W_2 = convolve(W, kernel, mode='same') / convolve(jnp.ones_like(W), kernel, mode='same')
jnp.allclose(new_W, new_W_2)
# True
假设我有一个形状为 (n_words, model_dim)
的矩阵 W
,其中 n_words
是句子中的单词数,model_dim
是 space表示词向量的地方。计算这些向量的移动平均值的最快方法是什么?
例如,window 大小为 2(window 长度 = 5),我可以有这样的东西(这会引发错误 TypeError: JAX 'Tracer' objects do not support item assignment
):
from jax import random
import jax.numpy as jnp
# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32))
ws = 2 # window size
N = W.shape[0] # number of words
new_W = jnp.zeros(W.shape)
for i in range(N):
window = W[max(0, i-ws):min(N, i+ws+1)]
n = window.shape[0]
for j in range(n):
new_W[i] += W[j] / n
我想 jnp.convolve
有一个更快的解决方案,但我不熟悉它。
您似乎在尝试进行卷积,因此 jnp.convolve
或类似方法可能是一种更高效的方法。
就是说,您的示例有点奇怪,因为 n
永远不会大于 4,因此您只能访问 W
的前四个元素。此外,您在内循环的每次迭代中都会覆盖先前的值,因此 new_W
的每一行仅包含 W
.
将您的代码更改为我认为您的意思并使用 index_update 使其与 JAX 的不可变数组兼容给出了这个:
from jax import random
import jax.numpy as jnp
# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0), shape=(17, 32))
ws = 2 # window size
N = W.shape[0] # number of words
new_W = jnp.zeros(W.shape)
for i in range(N):
window = W[max(0, i-ws):min(N, i+ws)]
n = window.shape[0]
for j in range(n):
new_W = new_W.at[i].add(window[j] / n)
这里是一个更高效的卷积的等价物:
from jax.scipy.signal import convolve
kernel = jnp.ones((4, 1))
new_W_2 = convolve(W, kernel, mode='same') / convolve(jnp.ones_like(W), kernel, mode='same')
jnp.allclose(new_W, new_W_2)
# True