为什么 model.fit 需要二维张量?为什么 model.predict 不接受标量张量?

Why does model.fit require two-dimensional tensors? And why does model.predict not accept scalar tensors?

我在学习 TensorFlow.js 时注意到 model.fit 必须接受两个参数,输入和输出,以及一些配置。但是输入的是二维张量,写法如下:

let input = tf.tensor2d([1, 2, 3, 4, 5], [5, 1])

这看起来非常像一个一维张量,写成如下:

let input = tf.tensor1d([1, 2, 3, 4, 5])

而且由于二维张量实际上是 5 x 1,我决定用一维张量代替它。但是,这完全停止了程序的运行。那么是否有某种类型的代码表明输入必须是二维的?如果是,为什么?

关于多维张量的话题,我还注意到 model.predict 不能接受零维张量或标量。见下文

Working Code:

model.predict(tf.tensor1d([6]))

Not Working Code:

model.predict(tf.scalar(6))

如果有人能阐明这些限制背后的原因,将不胜感激。

二维张量不是一维张量。 tf.tensor2d([1, 2, 3, 4, 5], [5, 1]) 不是 tf.tensor1d([1, 2, 3, 4, 5])。一个可以转换为另一个,但这并不意味着它们是平等的。

model.fit 将张量或秩 2 或更高作为参数。这个张量可以看作是一个元素数组,其形状被赋予模型的输入。该模型的 inputShape 至少是等级 1,这使得 model.fit 参数至少为 2(1+1 它始终是 inputShape 的等级 + 1)。

由于 model.fit 和 model.predict 将相同等级的张量作为参数,因此 model.predict 参数是等级 2 或更高的张量,原因与上述相同。

但是,model.predict(tf.tensor1d([6])) 有效。这样做是因为在内部,tensorflow.js 会将一维张量转换为二维张量。形状为 [6] 的初始张量将被转换为形状为 [6, 1] 的张量。

model.predict(tf.tensor1d([6])) 
// will work because it is a 1D tensor 
// and only in the case where the model first layer inputShape is [1]

model.predict(tf.tensor2d([[6]])) 
// will also work
// One rank higher than the inputShape and of shape [1, ...InputShape]

model.predict(tf.scalar(6)) // will not work

const model = tf.sequential(
    {layers: [tf.layers.dense({units: 1, inputShape: [1]})]});
model.predict(tf.ones([3])).print(); // works
model.predict(tf.ones([3, 1])).print(); // works
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>

const model = tf.sequential(
    {layers: [tf.layers.dense({units: 1, inputShape: [2]})]});
model.predict(tf.ones([2, 2])).print(); // works
model.predict(tf.ones([2])).print(); // will not work
   // because [2] is converted to [2, 1]
   // whereas the model is expecting an input of shape [b, 2] with b an integer
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>