使用 jax 计算按行(或按轴)点积的最佳方法是什么?
What's the best way to compute row-wise (or axis-wise) dot products with jax?
我有两个形状为 (N, M) 的数值数组。我想计算逐行点积。 IE。生成一个形状为 (N,) 的数组,使得第 n 行是每个数组中第 n 行的点积。
我知道 numpy 的 inner1d
方法。使用 jax 执行此操作的最佳方法是什么? jax 有 jax.numpy.inner
,但这还有其他作用。
你可以试试jax.numpy.einsum。这里使用 numpy einsum
实现
import numpy as np
from numpy.core.umath_tests import inner1d
arr1 = np.random.randint(0,10,[5,5])
arr2 = np.random.randint(0,10,[5,5])
arr = np.inner1d(arr1,arr2)
arr
array([ 87, 200, 229, 81, 53])
np.einsum('...i,...i->...',arr1,arr2)
array([ 87, 200, 229, 81, 53])
您可以在几行 jax 代码中定义自己的 inner1d
jit 编译版本:
import jax
@jax.jit
def inner1d(X, Y):
return (X * Y).sum(-1)
正在测试:
import jax.numpy as jnp
import numpy as np
from numpy.core import umath_tests
X = np.random.rand(5, 10)
Y = np.random.rand(5, 10)
print(umath_tests.inner1d(X, Y))
print(inner1d(jnp.array(X), jnp.array(Y)))
# [2.23219571 2.1013316 2.70353783 2.14094973 2.62582531]
# [2.2321959 2.1013315 2.703538 2.1409497 2.6258256]
我有两个形状为 (N, M) 的数值数组。我想计算逐行点积。 IE。生成一个形状为 (N,) 的数组,使得第 n 行是每个数组中第 n 行的点积。
我知道 numpy 的 inner1d
方法。使用 jax 执行此操作的最佳方法是什么? jax 有 jax.numpy.inner
,但这还有其他作用。
你可以试试jax.numpy.einsum。这里使用 numpy einsum
实现import numpy as np
from numpy.core.umath_tests import inner1d
arr1 = np.random.randint(0,10,[5,5])
arr2 = np.random.randint(0,10,[5,5])
arr = np.inner1d(arr1,arr2)
arr
array([ 87, 200, 229, 81, 53])
np.einsum('...i,...i->...',arr1,arr2)
array([ 87, 200, 229, 81, 53])
您可以在几行 jax 代码中定义自己的 inner1d
jit 编译版本:
import jax
@jax.jit
def inner1d(X, Y):
return (X * Y).sum(-1)
正在测试:
import jax.numpy as jnp
import numpy as np
from numpy.core import umath_tests
X = np.random.rand(5, 10)
Y = np.random.rand(5, 10)
print(umath_tests.inner1d(X, Y))
print(inner1d(jnp.array(X), jnp.array(Y)))
# [2.23219571 2.1013316 2.70353783 2.14094973 2.62582531]
# [2.2321959 2.1013315 2.703538 2.1409497 2.6258256]