Keras 密集层输出 - 形状错误
Keras Dense Layer Output - Shape Error
我正在使用 LSTM 处理 NLP 问题。问题是具有 3 类 (1,2 & 3) 的多类分类。因此,我使用以下代码转换了目标 类:y_train=to_catgorical(y_train)
,对于 y_test
也是如此。
但是在输出dense层写model.add((Dense(3,activation='softmax'))
的时候,出现如下错误:
Error when checking target: expected dense_1 to have shape (None, 3) but got array with shape (658118, 4)
然而,当我将其固定为 model.add((Dense(4,activation='softmax'))
,即 n+1(要预测的 n=类 的数量)时,它起作用了。 但在 Keras 示例中,他们使用了 cifar10 数据集,并将 类 的数量用作 10,而不是 11,并且它有效。
同样在二进制分类的情况下(2 类 被预测)我们只使用 1 个输出,即 model.add(Dense(1,"sigmoid"))
。
我已经经历了 this post 但仍然无法为这件事找到令人信服的逻辑,所以我想通过 Keras 中的密集层输出形状来清除这个概念。
P.S。我的理解是,Keras 从 0:num_classes
考虑 类,所以我们必须再进行一次转换。但是要问一件事,'0' 向量层必须保持未使用状态,对吗?在 cifar10 数据集的情况下,类 是 0:9
,这就是 num_classes=10 起作用的原因吗?如果是这种情况,那么如果我们必须预测 3 类 说 (0,1 & 2) 那么我们可以使用 num_classes=3 对吗?
也许需要注意的是使用Keras框架提供了一个one hot encoding功能:
from keras.datasets import cifar10
from keras.utils import np_utils
y_label_train_OneHot = np_utils.to_categorical(y_label_train)
y_label_test_OneHot = np_utils.to_categorical(y_label_test)
可以看到这段代码code
总之,我个人认为使用不同的功能会导致不同的分类。
我正在使用 LSTM 处理 NLP 问题。问题是具有 3 类 (1,2 & 3) 的多类分类。因此,我使用以下代码转换了目标 类:y_train=to_catgorical(y_train)
,对于 y_test
也是如此。
但是在输出dense层写model.add((Dense(3,activation='softmax'))
的时候,出现如下错误:
Error when checking target: expected dense_1 to have shape (None, 3) but got array with shape (658118, 4)
然而,当我将其固定为 model.add((Dense(4,activation='softmax'))
,即 n+1(要预测的 n=类 的数量)时,它起作用了。 但在 Keras 示例中,他们使用了 cifar10 数据集,并将 类 的数量用作 10,而不是 11,并且它有效。
同样在二进制分类的情况下(2 类 被预测)我们只使用 1 个输出,即 model.add(Dense(1,"sigmoid"))
。
我已经经历了 this post 但仍然无法为这件事找到令人信服的逻辑,所以我想通过 Keras 中的密集层输出形状来清除这个概念。
P.S。我的理解是,Keras 从 0:num_classes
考虑 类,所以我们必须再进行一次转换。但是要问一件事,'0' 向量层必须保持未使用状态,对吗?在 cifar10 数据集的情况下,类 是 0:9
,这就是 num_classes=10 起作用的原因吗?如果是这种情况,那么如果我们必须预测 3 类 说 (0,1 & 2) 那么我们可以使用 num_classes=3 对吗?
也许需要注意的是使用Keras框架提供了一个one hot encoding功能:
from keras.datasets import cifar10
from keras.utils import np_utils
y_label_train_OneHot = np_utils.to_categorical(y_label_train)
y_label_test_OneHot = np_utils.to_categorical(y_label_test)
可以看到这段代码code 总之,我个人认为使用不同的功能会导致不同的分类。