Partial derivatives using Jax?
I'm confused by the Jax documentation. Here's what I'm trying to do:
def line(m,x,b):
    return m*x + b

grad(line)(1,2,3)
Error:
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-48-d14b17620b30> in <module>()
3
----> 4 grad(line)(1,2,3)
FilteredStackTrace: TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
6 frames
/usr/local/lib/python3.7/dist-packages/jax/api.py in _check_input_dtype_revderiv(name, holomorphic, allow_int, x)
844 elif not allow_int and not (dtypes.issubdtype(aval.dtype, np.floating) or
845 dtypes.issubdtype(aval.dtype, np.complexfloating)):
--> 846 raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype that "
847 "is a sub-dtype of np.floating or np.complexfloating), "
848 f"but got {aval.dtype.name}. If you want to use integer-valued "
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.
For reference, this is the official tutorial code I was working from:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
Result:
W_grad [-0.16965576 -0.8774648 -1.4901345 ]
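What am I doing wrong here? I gather that key is being used in some important way, but I can't work out why or how it is necessary. To answer this question, please adjust the code in the first block as needed so that the error goes away.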
Jax is telling you that it doesn't like integers. Calling grad(line)(1.,2.,3.) (with floats) resolves the problem.
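Since the question asks about partial derivatives, here is a minimal sketch of how the float fix combines with the `argnums` parameter that already appears in the tutorial code quoted in the question. The variable names `dm`, `dx`, `db`, and `all_partials` are only illustrative. By default, `grad` differentiates with respect to the first argument.

```python
import jax

def line(m, x, b):
    return m * x + b

# argnums selects which positional argument to differentiate with respect to.
# Inputs are floats, so the dtype check that raised the TypeError passes.
dm = jax.grad(line, argnums=0)(1.0, 2.0, 3.0)  # d(line)/dm = x
dx = jax.grad(line, argnums=1)(1.0, 2.0, 3.0)  # d(line)/dx = m
db = jax.grad(line, argnums=2)(1.0, 2.0, 3.0)  # d(line)/db = 1

# A tuple of argnums returns a tuple with one partial derivative per entry.
all_partials = jax.grad(line, argnums=(0, 1, 2))(1.0, 2.0, 3.0)

print(dm, dx, db)
print(all_partials)
```

Note that `key` in the tutorial is only used to generate random initial values for `W` and `b`, so it is not needed when the inputs are given explicitly as above.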
I think the error is pretty clear here:
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int32. If you want to use integer-valued inputs, use vjp or set allow_int to True.
To use grad(line)(1,2,3) with int32 inputs, change it to grad(line, allow_int=True)(1,2,3)