LSTM 模型中的错误形状
Bad shape in LSTM model
我正在使用 tensorflow js 我有这段代码来构建我的循环神经网络模型以解决分类问题 3 类,大小为 250 的实例,。
当我尝试拟合我的模型时出现以下错误消息:
错误:检查目标时出错:预期 dense_Dense1 的形状为 [3],但得到的数组形状为 [4827,1].
我对在 tfjs 中构建自己的模型还很陌生,我想我搞砸了张量形状
PS:我的数据集包含 4827 个实例,我的 embeddingSize 是 32
function buildModel(maxLen, vocabularySize, embeddingSize, numClasses)
{
const model = tensorflow.sequential();
model.add(tensorflow.layers.embedding(
{
inputDim: vocabularySize,
outputDim: embeddingSize,//embeddingSize = 32
inputLength: maxLen//maxLen = 250
}));
model.add(tensorflow.layers.lstm({units: embeddingSize/*, returnSequences: true*/}));
model.add(tensorflow.layers.dense({units: numClasses, activation: 'softmax'}));//numClasses = 3
return model;
}
const history = await model.fit(data, labels, {
epochs: epochs,
batchSize: batchSize,
validationSplit: validationSplit,
callbacks: () =>
{
console.log("Coucou");
}
});
console.log(history);
谢谢
您需要通过将 false
返回到 lstm 层
来更改层维度
model.add(tensorflow.layers.lstm({units: embeddingSize, returnSequences: false}));
问题是我的数据标签为 un 1D 格式(0、1 或 2),而不是 3D 格式([1,0,0]、[0,1,0]、[0,0, 1])
我正在使用 tensorflow js 我有这段代码来构建我的循环神经网络模型以解决分类问题 3 类,大小为 250 的实例,。 当我尝试拟合我的模型时出现以下错误消息:
错误:检查目标时出错:预期 dense_Dense1 的形状为 [3],但得到的数组形状为 [4827,1].
我对在 tfjs 中构建自己的模型还很陌生,我想我搞砸了张量形状
PS:我的数据集包含 4827 个实例,我的 embeddingSize 是 32
function buildModel(maxLen, vocabularySize, embeddingSize, numClasses)
{
const model = tensorflow.sequential();
model.add(tensorflow.layers.embedding(
{
inputDim: vocabularySize,
outputDim: embeddingSize,//embeddingSize = 32
inputLength: maxLen//maxLen = 250
}));
model.add(tensorflow.layers.lstm({units: embeddingSize/*, returnSequences: true*/}));
model.add(tensorflow.layers.dense({units: numClasses, activation: 'softmax'}));//numClasses = 3
return model;
}
const history = await model.fit(data, labels, {
epochs: epochs,
batchSize: batchSize,
validationSplit: validationSplit,
callbacks: () =>
{
console.log("Coucou");
}
});
console.log(history);
谢谢
您需要通过将 false
返回到 lstm 层
model.add(tensorflow.layers.lstm({units: embeddingSize, returnSequences: false}));
问题是我的数据标签为 un 1D 格式(0、1 或 2),而不是 3D 格式([1,0,0]、[0,1,0]、[0,0, 1])