使用 Keras 使用泊松采样标签提高 MLP 性能(对于 multi-class classification)

Improving MLP performance (for multi-class classification) with poisson sampled labels using Keras

我正在尝试使用完全连接的神经网络或多层感知器来执行多class class化:我的训练数据 (X) 是等长的不同 DNA 串。这些序列中的每一个都有一个与之关联的浮点值(例如 t_X),我用它来通过以下方式为我的数据模拟标签 (y)。 y ~ np.random.poisson(常数 * t_X)

训练完我的 Keras 模型(请参见下文)后,我制作了预测标签和测试标签的直方图,我面临的问题是我的模型似乎 class 错误地验证了很多序列,请参阅下面链接的图片。

Histogram link

我的训练数据如下所示:

X , Y  
CTATTACCTGCCCACGGTAAAGGCGTTCTGG,    1
TTTCTGCCCGCGGCCTGGCAATTGATACCGC,    6
TTTTTACACGCCTTGCGTAAAGCGGCACGGC,    4
TTGCTGCCTGGCCGATGGTCTATGCCGCTGC,    7

我一次性编码我的 Y 和我的 X 序列被转换成维度张量:(批量大小、序列长度、字符数),这些数字大约是 10,000 x 50 x 4

我的 keras 模型看起来像:

model = Sequential() 
model.add(Flatten())
model.add(Dense(100, activation='relu',input_shape=(50,4)))
model.add(Dropout(0.25))
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(len(one_hot_encoded_labels), activation='softmax'))

我尝试了以下不同的损失函数

#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.00001), metrics=['accuracy'])
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.0001), metrics=['mean_absolute_error',r_square])
#model.compile(loss='kullback_leibler_divergence',optimizer=Adam(lr=0.00001), metrics=['categorical_accuracy'])
#model.compile(loss=log_poisson_loss,optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
#model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
model.compile(loss='poisson',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])

损失表现合理;随着时代的增加,它会下降并变平。我尝试过不同的学习率、不同的优化器、每层中不同数量的神经元、不同数量的隐藏层和不同类型的正则化。

我认为我的模型总是将大多数预测标签放在测试数据的峰值附近,(请参阅链接的直方图),但它无法 class验证测试集中计数较少的序列.这是一个常见问题吗?

在不使用其他架构(如卷积或循环)的情况下,有谁知道我可以如何改进此模型的 class化性能?

Training data file

从您的直方图分布可以清楚地看出,您的测试数据集非常不平衡。我假设,你有相同的训练数据分布。那么这可能是 NN 表现不佳的原因,因为它没有足够的数据供许多 classes 学习这些特征。你可以尝试一些采样技术,这样它就可以比较每个 class.

之间的关系

这里有一个link,解释了这种不平衡数据集的各种方法。

其次,您可以通过交叉验证来检查模型的性能,您可以在其中轻松找到可约误差还是不可约误差。如果那是不可减少的错误,你就不能再改进了(你必须针对这种情况尝试另一种方法)。

第三,序列之间存在相关性。简单的前馈网络无法捕获这种关系。 Recurrent-network 可以捕获数据集中的此类依赖关系。这是简单的example。此示例适用于 binary-class,可以扩展为 multi-class,如您的情况。

对于loss-function的选择,完全是针对特定问题的。您可以 check this link,其中解释了何时以及哪种损失函数会有所帮助。