Tensorflow.js 预测的上限好像是1?

Tensorflow.js prediction seems to have an upper limit of 1?

我正在玩弄一些从 youtube 教程中获得的预测花卉数据的 tensorflow 代码。这是脚本(训练数据分配给变量"iris",测试数据分配给变量"irisTesting":

const trainingData = tf.tensor2d(iris.map(item => [
    item.sepal_length, item.petal_length, item.petal_width,
]));
const outputData = tf.tensor2d(iris.map(item => [
    item.species === "setosa" ? 1 : 0,
    item.species === "virginica" ? 1 : 0,
    item.species === "versicolor" ? 1 : 0,
    item.sepal_width
]));
const testingData = tf.tensor2d(irisTesting.map(item => [
    item.sepal_length, item.petal_length, item.petal_width
]));

const model = tf.sequential();

model.add(tf.layers.dense({
    inputShape: [3],
    activation: "sigmoid",
    units: 5,
}));
model.add(tf.layers.dense({
    inputShape: [5],
    activation: "sigmoid",
    units: 4,
}));
model.add(tf.layers.dense({
    activation: "sigmoid",
    units: 4,
}));
model.compile({
    loss: "meanSquaredError",
    optimizer: tf.train.adam(.06),
});
const startTime = Date.now();
model.fit(trainingData, outputData, {epochs: 100})
    .then((history) => {
         //console.log(history);
        console.log("Done training in " + (Date.now()-startTime) / 1000 + " seconds.");
        model.predict(testingData).print();
    });

当控制台打印预测的 sepal_width 时,它的上限似乎为 1。训练数据的 sepal_width 值远远超过 1,但这里是记录的数据:

Tensor
    [[0.9561102, 0.0028415, 0.0708825, 0.9997129],
     [0.0081552, 0.9410981, 0.0867947, 0.999761 ],
     [0.0346453, 0.1170913, 0.8383155, 0.9999373]]

最后(第四)列是预测的 sepal_width 值。预测值应该大于 1,但似乎有什么东西阻止它大于 1。

这是原代码: https://gist.github.com/learncodeacademy/a96d80a29538c7625652493c2407b6be

你最后一层的激活函数是sigmoid

Sigmoid 函数如下所示:

source

正如你所看到的,它被限制在 0 到 1 的范围内。所以如果你想要其他输出值,你需要相应地调整你的最后一个激活函数。

您在最后一层使用 S 形激活函数来预测 sepal_width。 Sigmoid 是介于 0 和 1 之间的连续函数。请参阅 Wikipedia 以获得更详尽的解释。

如果你想预测 sepal_width,你应该尝试使用不同的激活函数。有关可用激活函数的列表,您可以查看 Tensorflow's API page(这是针对 Python 版本,但对于 JavaScript 版本应该类似)。您可以尝试 'softplus''relu' 甚至 'linear',但我不能说这些是否适合您的应用程序。尝试并试验看看哪个是最好的。

来自 here 的原始代码解决了 class 化问题。在你的 outputData 中添加 item.sepal_width 是没有意义的,因为它不是另一个 class.