获取给定批处理值的字典键 - Python

Get dictionary keys for given batched values - Python

我定义了一个字典 A 并且想找到给定一批值的键 a:

def dictionary(r):
 return dict(enumerate(r))

def get_key(val, my_dict):
   for key, value in my_dict.items():
      if np.array_equal(val,value):
          return key
    

 # dictionary
 A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
 A = dictionary(A)

 a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
 keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)

预期的输出应该是: keys = [[1,2,3],[0,3,2]]

为什么我得到 None 作为输出?

JAX 转换如 vmap 通过 tracing 函数工作,这意味着它们用值的抽象表示替换值以提取编码在函数(有关此概念的详细介绍,请参阅 How to think in JAX)。

这意味着要正确使用 vmap,函数只能使用 JAX 方法,不能使用 numpy 方法,因此您使用 np.array_equal 会破坏抽象。

不幸的是,它实际上没有任何替代品,因为没有在具体 Python 字典中查找抽象 JAX 值的机制。如果你想对 JAX 值进行字典查找,你应该避免转换,只使用 Python 循环:

keys = jnp.array([[get_key(x, A) for x in row] for row in a])

另一方面,我怀疑这更像是一个 XY problem;您的目标不是在 jax 转换中查找字典值,而是解决某些问题。也许你应该问一个关于如何解决问题的问题,而不是如何使用你尝试过的解决方案来解决问题。

但是如果您不想直接使用字典,与 JAX 兼容的替代 get_key 实现可能如下所示:

def get_key(val, my_dict):
  keys = jnp.array(list(my_dict.keys()))
  values = jnp.array(list(my_dict.values()))
  return keys[jnp.where((values == val).all(-1), size=1)]