如何解决 JAX/Python 中的 ValueError `vector::reserve`?
How to resolve ValueError `vector::reserve` in JAX/Python?
编辑:GitHub 问题在这里:https://github.com/google/jax/issues/5190
我正在尝试使用 jit 优化以下函数:
@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
for item in pairs:
if item[0]!=item[1]:
uniques[label_map[item[0]], label_map[item[1]]] += 1
return uniques
这里使用上面的套路:
def _get_pairwise_frequencies(
data: pd.DataFrame, crosstab=False
) -> pd.DataFrame:
values = data.stack()
values.index = values.index.droplevel(1)
values.name = "vals"
values = optimize(values.to_frame())
pair = optimize(values.join(values, rsuffix="_2"))
label_map = dict()
for lbl, each in enumerate(values.vals.unique()):
label_map[each] = lbl
if not crosstab:
freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
return ((freq / freq.sum(1).ravel()).astype(np.float32))
else:
freq = pd.crosstab(pair["vals"], pair["vals_2"])
self.index = freq.index
return csr_matrix((freq / freq.sum(1)).astype(np.float32))
但我收到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)
<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
25 label_map[each] = lbl
26 if not crosstab:
---> 27 freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
28 return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
29 else:
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 return cache_miss(*args, **kwargs)[0] # probably won't return
370 else:
--> 371 return cpp_jitted_f(*args, **kwargs)
372 f_jitted._cpp_jitted_f = cpp_jitted_f
373
ValueError: vector::reserve
问题的根源是什么?不使用 static_argnums
错误消息是
RuntimeError: Invalid argument: Unknown NumPy type O size 8
具有相同的回溯。
问题是您返回的 scipy.sparse.lil_matrix
不是有效的 JAX 类型。 JAX jit
装饰器不能用作任意 Python 代码的编译器;它旨在优化 JAX 数组上的操作序列。
在这种情况下,最好的处理方法可能是从您的函数中删除 @partial(jit, ...)
装饰器;如果您想在这里使用 JAX jit 编译,您首先必须重写代码以避免 scipy.sparse
矩阵并改用 JAX 数组。
编辑:GitHub 问题在这里:https://github.com/google/jax/issues/5190
我正在尝试使用 jit 优化以下函数:
@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
for item in pairs:
if item[0]!=item[1]:
uniques[label_map[item[0]], label_map[item[1]]] += 1
return uniques
这里使用上面的套路:
def _get_pairwise_frequencies(
data: pd.DataFrame, crosstab=False
) -> pd.DataFrame:
values = data.stack()
values.index = values.index.droplevel(1)
values.name = "vals"
values = optimize(values.to_frame())
pair = optimize(values.join(values, rsuffix="_2"))
label_map = dict()
for lbl, each in enumerate(values.vals.unique()):
label_map[each] = lbl
if not crosstab:
freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
return ((freq / freq.sum(1).ravel()).astype(np.float32))
else:
freq = pd.crosstab(pair["vals"], pair["vals_2"])
self.index = freq.index
return csr_matrix((freq / freq.sum(1)).astype(np.float32))
但我收到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)
<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
25 label_map[each] = lbl
26 if not crosstab:
---> 27 freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
28 return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
29 else:
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 return cache_miss(*args, **kwargs)[0] # probably won't return
370 else:
--> 371 return cpp_jitted_f(*args, **kwargs)
372 f_jitted._cpp_jitted_f = cpp_jitted_f
373
ValueError: vector::reserve
问题的根源是什么?不使用 static_argnums
错误消息是
RuntimeError: Invalid argument: Unknown NumPy type O size 8
具有相同的回溯。
问题是您返回的 scipy.sparse.lil_matrix
不是有效的 JAX 类型。 JAX jit
装饰器不能用作任意 Python 代码的编译器;它旨在优化 JAX 数组上的操作序列。
在这种情况下,最好的处理方法可能是从您的函数中删除 @partial(jit, ...)
装饰器;如果您想在这里使用 JAX jit 编译,您首先必须重写代码以避免 scipy.sparse
矩阵并改用 JAX 数组。