TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type

TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type

我在 Jax 上测试基本模型时遇到了一些问题。例如,我正在尝试从 Jax 手动实现 value_and_grad() 函数以解决二进制 classification 问题。这是我的模型初始值设定项:

class MLP(nn.Module):
    num_neurons_per_layer: Sequence[int]

    @nn.compact
    def __call__(self, x):
        activation = x
        for i, num_neurons in enumerate(self.num_neurons_per_layer):
            activation = nn.Dense(num_neurons)(activation)
            if i != len(self.num_neurons_per_layer) - 1:
                activation = nn.relu(activation)
        return nn.sigmoid(activation)

这是我的 BCE 损失,它使用 vmap 更快地对样本进行批处理,所有样本都包裹在 jit:

def make_bce_loss(xs, ys):
    
    def bce_loss(params, model): 
        def cross_entropy(x, y):
            preds = model.apply(params, x)
            return y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
        return -jnp.mean(jax.vmap(cross_entropy)(xs, ys), axis=0)

    return jax.jit(bce_loss)

bce_loss = make_bce_loss(X, y)
value_and_grad_fn = jax.value_and_grad(bce_loss)

然后我继续创建模型并初始化参数:

model = MLP(num_neurons_per_layer=[4, 1])
params = model.init(key, X)  # I create a jnp.array() to create X earlier on

当我测试 value_and_grad_fn(params, model) 的 jitted 版本时,出现以下错误:

TypeError: Argument 'MLP( # attributes num_neurons_per_layer = [4, 1] )' of type <class '__main__.MLP'> is not a valid JAX type.

我不确定我应该做什么来纠正这个问题。它抛出了关于 [4, 1] 的错误,但这些根本不参与计算,它们仅用于在 MLP class.

中初始化模型

Flax 模型必须标记为静态才能在 jit 中使用:

    return jax.jit(bce_loss, static_argnums=1)

看起来 flax 有一个问题需要改进此错误消息:https://github.com/google/flax/issues/853