获取给定批处理值的字典键 - Python
Get dictionary keys for given batched values - Python
我定义了一个字典 A
并且想找到给定一批值的键 a
:
def dictionary(r):
return dict(enumerate(r))
def get_key(val, my_dict):
for key, value in my_dict.items():
if np.array_equal(val,value):
return key
# dictionary
A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
A = dictionary(A)
a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)
预期的输出应该是:
keys = [[1,2,3],[0,3,2]]
为什么我得到 None
作为输出?
JAX 转换如 vmap
通过 tracing 函数工作,这意味着它们用值的抽象表示替换值以提取编码在函数(有关此概念的详细介绍,请参阅 How to think in JAX)。
这意味着要正确使用 vmap
,函数只能使用 JAX 方法,不能使用 numpy 方法,因此您使用 np.array_equal
会破坏抽象。
不幸的是,它实际上没有任何替代品,因为没有在具体 Python 字典中查找抽象 JAX 值的机制。如果你想对 JAX 值进行字典查找,你应该避免转换,只使用 Python 循环:
keys = jnp.array([[get_key(x, A) for x in row] for row in a])
另一方面,我怀疑这更像是一个 XY problem;您的目标不是在 jax 转换中查找字典值,而是解决某些问题。也许你应该问一个关于如何解决问题的问题,而不是如何使用你尝试过的解决方案来解决问题。
但是如果您不想直接使用字典,与 JAX 兼容的替代 get_key
实现可能如下所示:
def get_key(val, my_dict):
keys = jnp.array(list(my_dict.keys()))
values = jnp.array(list(my_dict.values()))
return keys[jnp.where((values == val).all(-1), size=1)]
我定义了一个字典 A
并且想找到给定一批值的键 a
:
def dictionary(r):
return dict(enumerate(r))
def get_key(val, my_dict):
for key, value in my_dict.items():
if np.array_equal(val,value):
return key
# dictionary
A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
A = dictionary(A)
a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)
预期的输出应该是:
keys = [[1,2,3],[0,3,2]]
为什么我得到 None
作为输出?
JAX 转换如 vmap
通过 tracing 函数工作,这意味着它们用值的抽象表示替换值以提取编码在函数(有关此概念的详细介绍,请参阅 How to think in JAX)。
这意味着要正确使用 vmap
,函数只能使用 JAX 方法,不能使用 numpy 方法,因此您使用 np.array_equal
会破坏抽象。
不幸的是,它实际上没有任何替代品,因为没有在具体 Python 字典中查找抽象 JAX 值的机制。如果你想对 JAX 值进行字典查找,你应该避免转换,只使用 Python 循环:
keys = jnp.array([[get_key(x, A) for x in row] for row in a])
另一方面,我怀疑这更像是一个 XY problem;您的目标不是在 jax 转换中查找字典值,而是解决某些问题。也许你应该问一个关于如何解决问题的问题,而不是如何使用你尝试过的解决方案来解决问题。
但是如果您不想直接使用字典,与 JAX 兼容的替代 get_key
实现可能如下所示:
def get_key(val, my_dict):
keys = jnp.array(list(my_dict.keys()))
values = jnp.array(list(my_dict.values()))
return keys[jnp.where((values == val).all(-1), size=1)]