如何解决 JAX/Python 中的 ValueError `vector::reserve`?

How to resolve ValueError `vector::reserve` in JAX/Python?

编辑:GitHub 问题在这里:https://github.com/google/jax/issues/5190

我正在尝试使用 jit 优化以下函数:

@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
    uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
    for item in pairs:
        if item[0]!=item[1]:
            uniques[label_map[item[0]], label_map[item[1]]] += 1
    return uniques

这里使用上面的套路:

def _get_pairwise_frequencies(
     data: pd.DataFrame, crosstab=False
    ) -> pd.DataFrame:
        values = data.stack()
        values.index = values.index.droplevel(1)
        values.name = "vals"
        values = optimize(values.to_frame())
        pair = optimize(values.join(values, rsuffix="_2"))
        label_map = dict()
        for lbl, each in enumerate(values.vals.unique()):
            label_map[each] = lbl
        if not crosstab:
            freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
            return ((freq / freq.sum(1).ravel()).astype(np.float32))
        else:
            freq = pd.crosstab(pair["vals"], pair["vals_2"])
            self.index = freq.index
            return csr_matrix((freq / freq.sum(1)).astype(np.float32))

但我收到以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)

<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
     25             label_map[each] = lbl
     26         if not crosstab:
---> 27             freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
     28             return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
     29         else:

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    369         return cache_miss(*args, **kwargs)[0]  # probably won't return
    370     else:
--> 371       return cpp_jitted_f(*args, **kwargs)
    372   f_jitted._cpp_jitted_f = cpp_jitted_f
    373 

ValueError: vector::reserve

问题的根源是什么?不使用 static_argnums 错误消息是

RuntimeError: Invalid argument: Unknown NumPy type O size 8

具有相同的回溯。

问题是您返回的 scipy.sparse.lil_matrix 不是有效的 JAX 类型。 JAX jit 装饰器不能用作任意 Python 代码的编译器;它旨在优化 JAX 数组上的操作序列。

在这种情况下,最好的处理方法可能是从您的函数中删除 @partial(jit, ...) 装饰器;如果您想在这里使用 JAX jit 编译,您首先必须重写代码以避免 scipy.sparse 矩阵并改用 JAX 数组。