RNN class化失败,除非我使用嵌入。没有嵌入,所有预测都是相同的 class

RNN classification fails unless I use an embedding. Without embedding, all predictions are the same class

我正在尝试class化一个长度为 200 的向量列表 X,其中包含从长度为 100 的字典 vocab 中选择的整数值作为属于到 class 0 或 1。这是我的输入数据的示例:

X=[[1,22,77,54,51,...],[2,3,1,41,3,...],[12,17,31,4,12,...],...]
y=[0,1,1,...]

例如 np.array(X).shape=(1000,200)y.shape=(200,)。 classes 分成 50-50。我做了一个标准的 train_test 拆分为 (X,y) 和 (X_test,y_test)。

我的模型是:

from keras import layers as L
model = keras.models.Sequential()
model.add(L.Embedding(input_dim=100+1,output_dim=32,\
                      input_length=200))
model.add(L.SimpleRNN(64))           
model.add(L.Dense(1,activation='sigmoid'))
model.compile(optimizer='adam',loss='binary_crossentropy',\
              metrics=['accuracy'])

model.fit(X,y,batch_size=128,epochs=20,validation_data=(X_test,y_text))

当我将它拟合到训练和测试数据时,它工作得相当好。但是,我想尝试跳过嵌入,因为我有 "small" space 个特征 (9026)。我通过除以 9206. 对训练数据进行归一化,并尝试按如下方式构建简单的 RNN 模型:

model = keras.models.Sequential()
model.add(L.InputLayer(200,1))
model.add(L.SimpleRNN(64))           
model.add(L.Dense(1,activation='sigmoid'))
model.compile(optimizer='adam',loss='binary_crossentropy',\
              metrics=['accuracy'])
model.fit(X[:,:,np.newaxis],y,batch_size=128,epochs=20,validation_data=(X_test[:,:,np.newaxis],y_text))

我必须添加 np.newaxis 才能编译模型。当我将其与数据相匹配时,我总是得到 0.5 的训练和验证精度,这是 class 0 到 class 1 的分数。我尝试了不同的激活、不同的优化器、不同数量的RNN 中的单元、不同的批次大小、LSTM、GRU、添加 dropout、多层……没有任何效果。

我的问题是:

  1. 我有固定长度 (200) 的向量到 classify,词汇表只有 100 个特征。没有嵌入就不能做到这一点吗?

  2. 有没有人对让非嵌入模型进行实际训练有有用的建议?

循环层需要形状为 (batch_size, timesteps, input_dim) 的输入,其中 input_dim 是输入数据中类别的数量,并且这些类别已经过单热编码,例如[1, 3], [0, 2] 变为 [[0, 1, 0, 0], [0, 0, 0, 1]], [[1, 0, 0, 0], [0, 0, 1, 0]].

现在你的数据的形状是 (batch_size, timesteps) 并且是稀疏编码的,这意味着上面编码中 1 的位置由类别编号隐式给出。只需向数组添加一个新轴即可使其形状正确,因此 Keras 不会引发任何错误,但数据编码不正确,因此您的训练显然根本不起作用。

它实际上适用于 Embedding 层,因为与 Recurrent 层相反,嵌入层 期望给定形状和编码的输入 (比较 RNN with the one of Embedding).

的输入形状

要解决这个问题,你只需要one-hot encode your data. Keras provides the very convenient to_categorical util function for this, but you also might do it