使用 Keras 使用泊松采样标签提高 MLP 性能(对于 multi-class classification)
Improving MLP performance (for multi-class classification) with poisson sampled labels using Keras
我正在尝试使用完全连接的神经网络或多层感知器来执行多class class化:我的训练数据 (X) 是等长的不同 DNA 串。这些序列中的每一个都有一个与之关联的浮点值(例如 t_X),我用它来通过以下方式为我的数据模拟标签 (y)。
y ~ np.random.poisson(常数 * t_X)。
训练完我的 Keras 模型(请参见下文)后,我制作了预测标签和测试标签的直方图,我面临的问题是我的模型似乎 class 错误地验证了很多序列,请参阅下面链接的图片。
我的训练数据如下所示:
X , Y
CTATTACCTGCCCACGGTAAAGGCGTTCTGG, 1
TTTCTGCCCGCGGCCTGGCAATTGATACCGC, 6
TTTTTACACGCCTTGCGTAAAGCGGCACGGC, 4
TTGCTGCCTGGCCGATGGTCTATGCCGCTGC, 7
我一次性编码我的 Y 和我的 X 序列被转换成维度张量:(批量大小、序列长度、字符数),这些数字大约是 10,000 x 50 x 4
我的 keras 模型看起来像:
model = Sequential()
model.add(Flatten())
model.add(Dense(100, activation='relu',input_shape=(50,4)))
model.add(Dropout(0.25))
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(len(one_hot_encoded_labels), activation='softmax'))
我尝试了以下不同的损失函数
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.00001), metrics=['accuracy'])
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.0001), metrics=['mean_absolute_error',r_square])
#model.compile(loss='kullback_leibler_divergence',optimizer=Adam(lr=0.00001), metrics=['categorical_accuracy'])
#model.compile(loss=log_poisson_loss,optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
#model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
model.compile(loss='poisson',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
损失表现合理;随着时代的增加,它会下降并变平。我尝试过不同的学习率、不同的优化器、每层中不同数量的神经元、不同数量的隐藏层和不同类型的正则化。
我认为我的模型总是将大多数预测标签放在测试数据的峰值附近,(请参阅链接的直方图),但它无法 class验证测试集中计数较少的序列.这是一个常见问题吗?
在不使用其他架构(如卷积或循环)的情况下,有谁知道我可以如何改进此模型的 class化性能?
从您的直方图分布可以清楚地看出,您的测试数据集非常不平衡。我假设,你有相同的训练数据分布。那么这可能是 NN 表现不佳的原因,因为它没有足够的数据供许多 classes 学习这些特征。你可以尝试一些采样技术,这样它就可以比较每个 class.
之间的关系
这里有一个link,解释了这种不平衡数据集的各种方法。
其次,您可以通过交叉验证来检查模型的性能,您可以在其中轻松找到可约误差还是不可约误差。如果那是不可减少的错误,你就不能再改进了(你必须针对这种情况尝试另一种方法)。
第三,序列之间存在相关性。简单的前馈网络无法捕获这种关系。 Recurrent-network
可以捕获数据集中的此类依赖关系。这是简单的example。此示例适用于 binary-class,可以扩展为 multi-class
,如您的情况。
对于loss-function
的选择,完全是针对特定问题的。您可以 check this link,其中解释了何时以及哪种损失函数会有所帮助。
我正在尝试使用完全连接的神经网络或多层感知器来执行多class class化:我的训练数据 (X) 是等长的不同 DNA 串。这些序列中的每一个都有一个与之关联的浮点值(例如 t_X),我用它来通过以下方式为我的数据模拟标签 (y)。 y ~ np.random.poisson(常数 * t_X)。
训练完我的 Keras 模型(请参见下文)后,我制作了预测标签和测试标签的直方图,我面临的问题是我的模型似乎 class 错误地验证了很多序列,请参阅下面链接的图片。
我的训练数据如下所示:
X , Y
CTATTACCTGCCCACGGTAAAGGCGTTCTGG, 1
TTTCTGCCCGCGGCCTGGCAATTGATACCGC, 6
TTTTTACACGCCTTGCGTAAAGCGGCACGGC, 4
TTGCTGCCTGGCCGATGGTCTATGCCGCTGC, 7
我一次性编码我的 Y 和我的 X 序列被转换成维度张量:(批量大小、序列长度、字符数),这些数字大约是 10,000 x 50 x 4
我的 keras 模型看起来像:
model = Sequential()
model.add(Flatten())
model.add(Dense(100, activation='relu',input_shape=(50,4)))
model.add(Dropout(0.25))
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(len(one_hot_encoded_labels), activation='softmax'))
我尝试了以下不同的损失函数
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.00001), metrics=['accuracy'])
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.0001), metrics=['mean_absolute_error',r_square])
#model.compile(loss='kullback_leibler_divergence',optimizer=Adam(lr=0.00001), metrics=['categorical_accuracy'])
#model.compile(loss=log_poisson_loss,optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
#model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
model.compile(loss='poisson',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
损失表现合理;随着时代的增加,它会下降并变平。我尝试过不同的学习率、不同的优化器、每层中不同数量的神经元、不同数量的隐藏层和不同类型的正则化。
我认为我的模型总是将大多数预测标签放在测试数据的峰值附近,(请参阅链接的直方图),但它无法 class验证测试集中计数较少的序列.这是一个常见问题吗?
在不使用其他架构(如卷积或循环)的情况下,有谁知道我可以如何改进此模型的 class化性能?
从您的直方图分布可以清楚地看出,您的测试数据集非常不平衡。我假设,你有相同的训练数据分布。那么这可能是 NN 表现不佳的原因,因为它没有足够的数据供许多 classes 学习这些特征。你可以尝试一些采样技术,这样它就可以比较每个 class.
之间的关系这里有一个link,解释了这种不平衡数据集的各种方法。
其次,您可以通过交叉验证来检查模型的性能,您可以在其中轻松找到可约误差还是不可约误差。如果那是不可减少的错误,你就不能再改进了(你必须针对这种情况尝试另一种方法)。
第三,序列之间存在相关性。简单的前馈网络无法捕获这种关系。 Recurrent-network
可以捕获数据集中的此类依赖关系。这是简单的example。此示例适用于 binary-class,可以扩展为 multi-class
,如您的情况。
对于loss-function
的选择,完全是针对特定问题的。您可以 check this link,其中解释了何时以及哪种损失函数会有所帮助。