Deep Learning Toolbox 中 DBN 的错误结果

Bad results from DBN in Deep Learning Toolbox

我要运行this example。当我使用 mnist_uint8 数据时,我可以很好地 运行 这段代码。但是如果我运行一个模型(比如DBN.m)使用我自己的数据,这个代码:

[er, bad] = nntest(nn, test_x, test_y); 

将运行什么都没有,er是零。为什么会这样?我的训练数据的输入大小是320*200,输出是320*1。

编辑:添加代码和数据文件

load dataX 
load dataY 
load pdataX 
load pdataY
train_x=dataX/100
test_x=pdataX/100
pdataY(find(pdataY(:,:)<=20))=0;
pdataY(find(pdataY(:,:)>20))=1;
dataY(find(dataY(:,:)<=20))=0;
dataY(find(dataY(:,:)>20))=1;
train_y=dataY
test_y=pdataY
rand('state',0); 
dbn.sizes = [100 40];

%train a 100-40 hidden unit DBN 
opts.numepochs = 1;
opts.batchsize = 40;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn nn = dbnunfoldtonn(dbn, 1);
nn.activation_function = 'sigm';
%train nn opts.numepochs = 1;
opts.batchsize = 40;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

这是数据

https://mega.co.nz/#!9V5wmKYK!q3nAvrzKZCT_Q3Ae-DDNAGDnV57b6Pzq6gtf01w0lD8

编辑:

经过多次讨论(见评论),问题是目标(y)需要使用 one-of-N 编码格式进行训练和测试。例如 [1 0] 表示 class 1,[0 1] 表示 class 2。修改后的代码产生的基本错误率为 0.2125。进一步调整和架构更改应该会产生更好的结果。

clear all

load dataX 
load dataY 
load pdataX 
load pdataY
train_x=dataX/100;
test_x=pdataX/100;
pdataY(find(pdataY(:,:)<=20))=0;
pdataY(find(pdataY(:,:)>20))=1;
dataY(find(dataY(:,:)<=20))=0;
dataY(find(dataY(:,:)>20))=1;
train_y=dataY
test_y=pdataY

% Add dimension for one-of-N encoding
train_y(:,2) = 1-train_y(:,1);
test_y(:,2) = 1-test_y(:,1);

rand('state',0)
dbn.sizes = [100 40];

%train a 100-40 hidden unit DBN
opts.numepochs = 2;
opts.batchsize = 40;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 2);
nn.activation_function = 'sigm';

%train nn
opts.numepochs = 100;
opts.batchsize = 40;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

原答案:

我假设您的训练数据是 200 个特征和 320 个训练示例。假设您正确地训练了它,那么您可能需要执行特征缩减。我知道 MNIST 数据集上的 ML 算法 运行 很流行使用主成分分析(请参阅 Matlab 函数 pca())对其进行预处理以截断某些特征。请 post 更多代码让我们真正看到问题。