为什么当我的函数 np.power 中有一个 np.power 时不能给我导数?
why when i have a np.power in my function jax.grad can't give me the derivitives?
我想训练一个简单的线性模型。下面的 x 和 y 是我的数据。
import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)
f 是计算所有数据的均方误差的函数。
def f(params, x, y):
return np.mean(np.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()
#initialize parameters
params['w'] = 2.4
params['b'] = 10.
df(params, x, y) # I will do this in a loop (implementing gradient decent part
这给了我一个错误:
FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray
当我清除 np.power
代码时。为什么?
JAX 无法计算 numpy
函数的梯度,但它可以计算 jax.numpy
函数的梯度。如果您根据 jax.numpy
重写代码,它应该适合您:
import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)
import jax.numpy as jnp
def f(params, x, y):
return jnp.mean(jnp.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()
params['w'] = 2.4
params['b'] = 10.
df(params, x, y)
# {'b': DeviceArray(14.661432, dtype=float32),
# 'w': DeviceArray(7.3792152, dtype=float32)}
您可以在 TracerArrayConversionError
documentation page 中阅读更多详细信息。
我想训练一个简单的线性模型。下面的 x 和 y 是我的数据。
import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)
f 是计算所有数据的均方误差的函数。
def f(params, x, y):
return np.mean(np.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()
#initialize parameters
params['w'] = 2.4
params['b'] = 10.
df(params, x, y) # I will do this in a loop (implementing gradient decent part
这给了我一个错误:
FilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray
当我清除 np.power
代码时。为什么?
JAX 无法计算 numpy
函数的梯度,但它可以计算 jax.numpy
函数的梯度。如果您根据 jax.numpy
重写代码,它应该适合您:
import numpy as np
x = np.linspace(0,1,100)
y = 2 * x + 3 + np.random.randn(100)
import jax.numpy as jnp
def f(params, x, y):
return jnp.mean(jnp.power((params['w'] * x + params['b'])-y , 2))
from jax import grad
df = grad(f)
params = dict()
params['w'] = 2.4
params['b'] = 10.
df(params, x, y)
# {'b': DeviceArray(14.661432, dtype=float32),
# 'w': DeviceArray(7.3792152, dtype=float32)}
您可以在 TracerArrayConversionError
documentation page 中阅读更多详细信息。