How to implement k-fold cross validation with multi-class SVM

I am working on age prediction, and I can implement a multi-class SVM with 11 classes by training one-vs-all classifiers (each class as positive against all the rest), as shown here and here. But the problem is the for loops: as shown below, training requires 11 iterations:

for k = 1:numClasses
    % Vectorized statement that binarizes Group:
    % 1 is the current class and 0 is all other classes.
    G_x_All = double(train_label == u(k));
    SVMStruct{k} = svmtrain(Data_Set, G_x_All);
end

Then classification also needs to loop 11 times for each image:

for j = 1:total_images
    for k = 1:numClasses
        if svmclassify(SVMStruct{k}, Test_Img(j,:))
            break;
        end
    end
    Age(j) = u(k); % Put the number of the matching class in the Age vector
end

My newbie question is: after all these loops, how do I do k-fold cross validation?

EDIT:

Here is the code as last updated following zelanix's suggestion, but I am getting poor results. Can you help me improve its performance?

u = unique(train_label);
numClasses = length(u);
N = size(Data_Set, 1);
A = 10;
indices = crossvalind('Kfold', N, A);
cp = classperf(train_label);

for i = 1:A
    Test = (indices == i);
    Train = ~Test;
    SVMStruct = cell(numClasses, 1); % Clear data structure.

    % Build models
    for k = 1:numClasses
        % Vectorized statement that binarizes Group:
        % 1 is the current class and 0 is all other classes.
        G_x_All = double(train_label == u(k));
        SVMStruct{k} = svmtrain(Data_Set(Train,:), G_x_All(Train,:));
    end

    Age = NaN(size(Data_Set, 1), 1);

    % Classify test cases
    for k = 1:numClasses
        if svmclassify(SVMStruct{k}, Data_Set(Test,:))
            break;
        end
    end
    Age = u(k);
    if Age == 1
        disp('Under 10 years old');
    elseif Age == 10
        disp('Between 10 and 20 years old');
    elseif Age == 20
        disp('Between 20 and 30 years old');
    elseif Age == 30
        disp('Between 30 and 40 years old');
    elseif Age == 40
        disp('Between 40 and 50 years old');
    elseif Age == 50
        disp('Between 50 and 60 years old');
    elseif Age == 60
        disp('Over 60 years old');
    else
        disp('Unknown');
    end

    classperf(cp, Age, Test);
    disp(i)
end
cp.CorrectRate

Note that I reduced the number of labels from 11 to 7.

The general structure you need is as follows (assuming your data is in a variable your_data of size N x M, where N is the number of samples and M is the number of features, and your class labels are in a variable your_classes of size N x 1):

K = 10; % The number of folds
N = size(your_data, 1); % The number of data samples to train / test
idx = crossvalind('Kfold', N, K);

% your_classes should contain class labels between 1 and numClasses.
cp = classperf(your_classes);

for i = 1:K
    Data_Set = your_data(idx ~= i, :); % The data to train on, 90% of the total.
    train_label = your_classes(idx ~= i, :); % The class labels of your training data.
    Test_Img = your_data(idx == i, :); % The data to test on, 10% of the total.
    test_label = your_classes(idx == i, :); % The class labels of your test data.

    SVMStruct = cell(numClasses, 1); % Clear data structure.

    % Your training routine, copied verbatim
    for k = 1:numClasses
        %Vectorized statement that binarizes Group
        %where 1 is the current class and 0 is all other classes
        G_x_All = (train_label == u(k));
        G_x_All = double (G_x_All);
        SVMStruct{k} = svmtrain(Data_Set, G_x_All);
    end

    Age = NaN(size(Test_Img, 1), 1);

    % Your test routine, copied (almost) verbatim
    for j = 1:size(Test_Img, 1)
      for k = 1:numClasses
          if(svmclassify(SVMStruct{k}, Test_Img(j,:)));
              break;
          end
      end
      Age(j) = u(k); % Put the number of correct class in Age vector
    end

    cp = classperf(cp, Age, idx == i);
end

cp.CorrectRate

This is untested and I'm not sure how your classification works. You seem to break on the first class that matches, which may not be the correct, or indeed the most likely, classification. You will also need some way of recording this and matching it against the true class labels in test_label. I suggest you look at the classperf function, but that is a separate question.

Also note that MATLAB has built-in multiclass SVM classification in the fitcecoc function, which may suit your needs better.
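For comparison, the same end-to-end pattern (a built-in multiclass SVM plus k-fold cross validation, with no manual per-class loop) can be sketched in Python with scikit-learn; the data here is a synthetic stand-in for the age features, so treat this only as an illustration of the structure:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC

# Synthetic stand-in: 200 samples, 10 features, 7 classes (as in the question).
X, y = make_classification(n_samples=200, n_features=10, n_informative=8,
                           n_classes=7, n_clusters_per_class=1, random_state=0)

# SVC handles multiclass natively, so there is no per-class training loop.
clf = SVC(kernel="linear")

# 10-fold cross validation, analogous to crossvalind('Kfold', N, 10).
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
scores = cross_val_score(clf, X, y, cv=cv)
print(round(scores.mean(), 3))  # mean accuracy over the 10 folds
```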

EDIT The problem with your updated code (as I mentioned above) is with your classification method. You loop through and test whether the sample belongs to each class, breaking on the first match. That is unlikely to be the most likely classification, so I'm not surprised you are getting poor results.

Your sample might match the first model by only a small margin and never reach a model that it matches much better. It is also unlikely to pass the first few classifiers at all, and if it gets to the end without any classifier matching, what do you do then?

Multiclass classification with SVMs is usually achieved by choosing the class for which the sample's classification is most likely (among all the classes for which it tests positive, choose the one furthest from the decision boundary; fitcecoc does this internally). But this cannot be done manually with svmclassify, because it does not give you access to those details.
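To make the "furthest from the decision boundary" idea concrete, here is a sketch of manual one-vs-rest scoring in Python with scikit-learn (whose decision_function exposes the signed distances that svmclassify hides). Instead of breaking at the first positive classifier, every binary model is scored and the argmax is taken; the data is again a synthetic placeholder:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_classification(n_samples=300, n_features=10, n_informative=8,
                           n_classes=7, n_clusters_per_class=1, random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2,
                                          random_state=1, stratify=y)

# One binary one-vs-rest model per class, like the SVMStruct{k} loop.
classes = np.unique(y_tr)
models = [SVC(kernel="linear").fit(X_tr, (y_tr == c).astype(int))
          for c in classes]

# Signed distance of each test sample to each binary decision boundary.
scores = np.column_stack([m.decision_function(X_te) for m in models])

# Pick the class whose boundary the sample lies furthest on the positive
# side of, rather than breaking at the first classifier that says "positive".
pred = classes[np.argmax(scores, axis=1)]
print((pred == y_te).mean())  # test accuracy
```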

fitcecoc still doesn't give you access to all the values you might need, so if you really want to do this manually then I suggest looking at libsvm; otherwise use fitcecoc.

As mentioned in the comments, svmtrain and svmclassify are now deprecated. libsvm also offers far greater tuning possibilities and performance that cannot be achieved with the built-in MATLAB implementations.

Incidentally, multiclass logistic regression is easier to understand and can also give good results.
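As a sketch of that alternative, a multinomial logistic regression gives one probability per class directly, so the prediction is the same argmax idea but over probabilities (shown here in Python with scikit-learn on placeholder data):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=300, n_features=10, n_informative=8,
                           n_classes=7, n_clusters_per_class=1, random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2,
                                          random_state=1, stratify=y)

# One model over all classes; predict_proba returns a probability per class.
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
proba = clf.predict_proba(X_te)

# Same argmax idea as the SVM decision values, but over probabilities.
pred = clf.classes_[np.argmax(proba, axis=1)]
print((pred == y_te).mean())  # test accuracy
```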