TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type
TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type
我在 Jax 上测试基本模型时遇到了一些问题。例如,我正在尝试从 Jax 手动实现 value_and_grad()
函数以解决二进制 classification 问题。这是我的模型初始值设定项:
class MLP(nn.Module):
num_neurons_per_layer: Sequence[int]
@nn.compact
def __call__(self, x):
activation = x
for i, num_neurons in enumerate(self.num_neurons_per_layer):
activation = nn.Dense(num_neurons)(activation)
if i != len(self.num_neurons_per_layer) - 1:
activation = nn.relu(activation)
return nn.sigmoid(activation)
这是我的 BCE 损失,它使用 vmap
更快地对样本进行批处理,所有样本都包裹在 jit
:
def make_bce_loss(xs, ys):
def bce_loss(params, model):
def cross_entropy(x, y):
preds = model.apply(params, x)
return y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
return -jnp.mean(jax.vmap(cross_entropy)(xs, ys), axis=0)
return jax.jit(bce_loss)
bce_loss = make_bce_loss(X, y)
value_and_grad_fn = jax.value_and_grad(bce_loss)
然后我继续创建模型并初始化参数:
model = MLP(num_neurons_per_layer=[4, 1])
params = model.init(key, X) # I create a jnp.array() to create X earlier on
当我测试 value_and_grad_fn(params, model)
的 jitted 版本时,出现以下错误:
TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type.
我不确定我应该做什么来纠正这个问题。它抛出了关于 [4, 1]
的错误,但这些根本不参与计算,它们仅用于在 MLP
class.
中初始化模型
Flax 模型必须标记为静态才能在 jit 中使用:
return jax.jit(bce_loss, static_argnums=1)
看起来 flax
有一个问题需要改进此错误消息:https://github.com/google/flax/issues/853
我在 Jax 上测试基本模型时遇到了一些问题。例如,我正在尝试从 Jax 手动实现 value_and_grad()
函数以解决二进制 classification 问题。这是我的模型初始值设定项:
class MLP(nn.Module):
num_neurons_per_layer: Sequence[int]
@nn.compact
def __call__(self, x):
activation = x
for i, num_neurons in enumerate(self.num_neurons_per_layer):
activation = nn.Dense(num_neurons)(activation)
if i != len(self.num_neurons_per_layer) - 1:
activation = nn.relu(activation)
return nn.sigmoid(activation)
这是我的 BCE 损失,它使用 vmap
更快地对样本进行批处理,所有样本都包裹在 jit
:
def make_bce_loss(xs, ys):
def bce_loss(params, model):
def cross_entropy(x, y):
preds = model.apply(params, x)
return y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
return -jnp.mean(jax.vmap(cross_entropy)(xs, ys), axis=0)
return jax.jit(bce_loss)
bce_loss = make_bce_loss(X, y)
value_and_grad_fn = jax.value_and_grad(bce_loss)
然后我继续创建模型并初始化参数:
model = MLP(num_neurons_per_layer=[4, 1])
params = model.init(key, X) # I create a jnp.array() to create X earlier on
当我测试 value_and_grad_fn(params, model)
的 jitted 版本时,出现以下错误:
TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type.
我不确定我应该做什么来纠正这个问题。它抛出了关于 [4, 1]
的错误,但这些根本不参与计算,它们仅用于在 MLP
class.
Flax 模型必须标记为静态才能在 jit 中使用:
return jax.jit(bce_loss, static_argnums=1)
看起来 flax
有一个问题需要改进此错误消息:https://github.com/google/flax/issues/853