Error in meanSquaredError: Shapes 10,1 and 10,2 must match (tensorflow.js)

Error in meanSquaredError: Shapes 10,1 and 10,2 must match (tensorflow.js)

我的代码从 csv 加载数据。然后我建立一个模型并将数据传递给它。然后我尝试用数据训练我的模型。

此时出现上述错误。由于我对 javascript 的经验很少,所以我不知道去哪里搜索。我认为它与我的 .batch-call 有关。如果我将行更改为“}).batch(20);”错误更改为:"Shapes 20,1 and 20,2 must match"。在我的理解中,批处理是在 trainmodel-function 的 "batchsize" 参数中设置的。我不知所措,我的错误所在。我的数据集有 196 个特征列和一个标签列。

  async train(): Promise<any> {
  const csvUrl = '/assets/little.csv';
  const csvDataset = tf.data.csv(
  csvUrl,
  {
    columnConfigs: {
      quit: {
        isLabel: true
      }
    },
    delimiter:','
  });
   const numOfFeatures = (await csvDataset.columnNames()).length -1;      
   console.log(numOfFeatures);
   const flattenedDataset =
   csvDataset
   .map(({xs, ys}: any) =>
     { 
     // Convert xs(features) and ys(labels) from object form (keyed by
     // column name) to array form.
     return {xs:Object.values(xs), ys:Object.values(ys)};
     }).batch(10);    
   console.log(flattenedDataset.toArray());      

   const model = tf.sequential({
     layers: [
       tf.layers.dense({inputShape: [196], units: 100, activation: 'relu'}),
       tf.layers.dense({units: 100, activation: 'relu'}),
       tf.layers.dense({units: 100, activation: 'relu'}),        
       tf.layers.dense({units: 2, activation: 'softmax'}),        
     ]
   }); 
   tfvis.show.modelSummary({name: 'Model Summary'}, model);     
   await trainModel(model, flattenedDataset);
   console.log('Done Training');
   }
}

async function trainModel(model, flattenedDataset) {
  // Prepare the model for training.  
  model.compile({
    optimizer: tf.train.adam(),
    loss: tf.losses.meanSquaredError,
    metrics: ['mse'],
  });

  const batchSize = 32;
  const epochs = 50;

  return await model.fitDataset(flattenedDataset, {
    batchSize,
    epochs,
    shuffle: true,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training Performance' },
      ['loss', 'mse'], 
      { height: 200, callbacks: ['onEpochEnd'] }
    )
});

最后一层有 units:2 而只有一列 quit 被设置为标签。

要么将另一列设置为标签,要么单元数应为1