选择 JAX 矩阵子集的最快方法是什么?
What is the fastest way of selecting a subset of a JAX matrix?
假设我有一个二维矩阵,我想在直方图中绘制它的值。为此,我需要做类似的事情:
list_1d = matrix_2d.reshape((-1,)).tolist()
然后使用列表绘制直方图。到目前为止一切顺利,只是原始矩阵中有一些我想排除的项目。为简单起见,假设我有一个这样的列表:
exclude = [(2, 5), (3, 4), (6, 1)]
所以,list_1d
应该包含矩阵中的所有项目,不包括 exclude
指向的项目(exclude
的项目是行和列索引)。
顺便说一句,matrix_2d
是一个 JAX 数组,这意味着它的内容在 GPU 中。
一种方法是创建一个掩码数组,用于 select 所需的数组子集。掩码索引操作 returns selected 数据的一维副本:
import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]
ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)
list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97
假设我有一个二维矩阵,我想在直方图中绘制它的值。为此,我需要做类似的事情:
list_1d = matrix_2d.reshape((-1,)).tolist()
然后使用列表绘制直方图。到目前为止一切顺利,只是原始矩阵中有一些我想排除的项目。为简单起见,假设我有一个这样的列表:
exclude = [(2, 5), (3, 4), (6, 1)]
所以,list_1d
应该包含矩阵中的所有项目,不包括 exclude
指向的项目(exclude
的项目是行和列索引)。
顺便说一句,matrix_2d
是一个 JAX 数组,这意味着它的内容在 GPU 中。
一种方法是创建一个掩码数组,用于 select 所需的数组子集。掩码索引操作 returns selected 数据的一维副本:
import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]
ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)
list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97