我对 tensorflow/tfjs LSTM 输入形状和 LSTM 基本理解有疑问

I have a question on tensorflow/tfjs LSTM inputShape and also LSTM basic understaning

我的训练数据集是针对一个用户的

let training = [
    [[2019.1], [10]],
    [[2019.2], [2]],
    [[2019.4], [11]],
    [[2019.5], [31]]
]

对于这个简单的输入,我想获得下个月的预测。先了解LSTM时间序列。使用以下代码生成训练模型

let train_x = training.map(i => {return i[0]})
let train_y = training.map(j=> {return j[1]})

const model = tf.sequential();
**model.add(tf.layers.lstm({units: 128, returnSequences: false,  inputShape:[train_x.length]}));**
model.add(tf.layers.dropout(0.2))
model.add(tf.layers.dense({units: training.length, activation: 'softmax'}));

model.compile({loss: 'categoricalCrossentropy', optimizer: tf.train.rmsprop(0.002)});

**const xs = tf.tensor3d([train_x]);**
const ys = tf.tensor2d(train_y, [training.length, train_y[0].length]);

错误:

Error: Input 0 is incompatible with layer lstm_LSTM1: expected ndim=3, found ndim=2

问题是应该给出什么输入形状以及什么应该是 tf.tensor3d 输入。据我了解,我正在尝试这个简单的例子。在没有keras的情况下尝试

错误信息很简单:

Input 0 is incompatible with layer lstm_LSTM1: expected ndim=3, found ndim=2

lstm 层需要 3d 输入。这意味着 inputShape 应该是 [a, b],其中 ab 都是数字(a 也可以是 null)。有一个包含 2 个元素的序列。 a因此是1(单序列:我们希望层在进行预测之前看到的序列数;看这里的数据好像是1,但是可以改成不同的值) b2(每个序列 2 个元素)。

培训将是:

training = [
    [[2019.1, 10]],
    [[2019.2, 2]],
    [[2019.4], [11]],
    [[2019.5], [31]]
]
xs = tf.tensor(training).reshape([-1, 1, 2])

而lstm层变成如下:

model.add(tf.layers.lstm({units: 128, returnSequences: false,  inputShape:[1, 2]}));

瞧,整个模型如下所示:

const model = tf.sequential();
model.add(tf.layers.lstm({units: 128, returnSequences: false,  inputShape:[1, 2]}));
model.add(tf.layers.dropout(0.2))
model.add(tf.layers.dense({units: 20, activation: 'softmax'}));
model.summary()
model.compile({loss: 'categoricalCrossentropy', optimizer: tf.train.rmsprop(0.002)});
model.summary()
const training = [
        [[2019.1, 10]],
        [[2019.2, 2]],
        [[2019.4], [11]],
        [[2019.5], [31]]
    ]
const xs = tf.tensor(training).reshape([-1, 1, 2])
await model.fit(xs, tf.ones([4, 20]))
model.predict(tf.ones([1, 1, 2])).print()