Using SMOTE in Java raises "Comparison method violates its general contract"


I am working on a project in Java and need to use the Weka API. I use Maven to manage dependencies; in particular, I have the following one:
<dependency>
   <groupId>nz.ac.waikato.cms.weka</groupId>
   <artifactId>weka-stable</artifactId>
   <version>3.8.5</version>
</dependency>

This version does not include the SMOTE class, but I really need it; that is why I also added the following dependency to my pom.xml:

<dependency>
   <groupId>nz.ac.waikato.cms.weka</groupId>
   <artifactId>SMOTE</artifactId>
   <version>1.0.2</version>
</dependency>

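In my Java code I am also trying to implement the walk-forward validation technique: I can prepare both the training set and the testing set for each step, so I can use them in a loop. What I do is the following: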

for (...){
   var filtered = new FilteredClassifier();
   var smote = new SMOTE();
   filtered.setFilter(smote);
   filtered.setClassifier(new NaiveBayes());
   filtered.buildClassifier(trainingDataset);
   var currEvaluation = new Evaluation(testingDataset);
   currEvaluation.evaluateModel(filtered, testingDataset);
}

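trainingDataset and testingDataset are of type Instances, and their values change appropriately at each iteration. The first iteration runs without problems, but in the second one a java.lang.IllegalArgumentException: Comparison method violates its general contract! is thrown. The exception stack trace is: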

java.lang.IllegalArgumentException: Comparison method violates its general contract!
    at java.base/java.util.TimSort.mergeLo(TimSort.java:781)
    at java.base/java.util.TimSort.mergeAt(TimSort.java:518)
    at java.base/java.util.TimSort.mergeCollapse(TimSort.java:448)
    at java.base/java.util.TimSort.sort(TimSort.java:245)
    at java.base/java.util.Arrays.sort(Arrays.java:1441)
    at java.base/java.util.List.sort(List.java:506)
    at java.base/java.util.Collections.sort(Collections.java:179)
    at weka.filters.supervised.instance.SMOTE.doSMOTE(SMOTE.java:637)
    at weka.filters.supervised.instance.SMOTE.batchFinished(SMOTE.java:489)
    at weka.filters.Filter.useFilter(Filter.java:708)
    at weka.classifiers.meta.FilteredClassifier.setUp(FilteredClassifier.java:719)
    at weka.classifiers.meta.FilteredClassifier.buildClassifier(FilteredClassifier.java:794)

Does anyone know how to solve this problem?
Thanks in advance.

Edit: I forgot to say that I am using Java 11.0.11.

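Edit 2: Based on @fracpete's answer, I suspect the problem may lie in how the sets are created. I should state that I am trying to predict the vulnerability of the classes of another open-source project. Because of walk-forward, I have 19 steps, which would require 19 different training files and 19 testing files. To avoid that, I keep a list of InfoKeeper objects, each holding the Instances used for training and testing at one step. While building this list, I do the following:

  1. From the base ARFF file I create 2 temporary files: the training set file holds the version 1 data and the testing set file holds the version 2 data. I then read these temporary ARFF files to create the Instances objects. These are kept by the InfoKeeper for step 1.
  2. I append the rows of the testing set file (only the data, of course) to the training set file, so that it holds the version 1 and version 2 data. Then I overwrite the testing file so that it holds the version 3 data. I read these temporary ARFF files to get the Instances that will be kept by the InfoKeeper for step 2.

The code repeats step 2 to create all the remaining InfoKeeper objects. Could this procedure be the problem?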
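
The walk-forward scheme described above (train on releases 1..k, test on release k + 1) can also be built entirely in memory, which avoids the append-and-overwrite bookkeeping on temporary files. The following is only a minimal sketch in plain Java: the Row and Step types and the split method are hypothetical names introduced for illustration and are not part of Weka or of the original project.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of Weka: builds walk-forward splits in memory.
class WalkForwardSketch {

    /** One data row tagged with the release it belongs to (illustrative type). */
    static final class Row {
        final int release;
        final String values;

        Row(int release, String values) {
            this.release = release;
            this.values = values;
        }
    }

    /** Training and testing rows for a single walk-forward step. */
    static final class Step {
        final List<Row> training = new ArrayList<>();
        final List<Row> testing = new ArrayList<>();
    }

    /**
     * Step k (1-based) trains on releases 1..k and tests on release k + 1.
     * Assumes release numbers run from 1 to lastRelease.
     */
    static List<Step> split(List<Row> rows, int lastRelease) {
        List<Step> steps = new ArrayList<>();
        for (int k = 1; k < lastRelease; k++) {
            Step step = new Step();
            for (Row row : rows) {
                if (row.release <= k) {
                    step.training.add(row);
                } else if (row.release == k + 1) {
                    step.testing.add(row);
                }
            }
            steps.add(step);
        }
        return steps;
    }

    public static void main(String[] args) {
        List<Row> rows = List.of(
            new Row(1, "a"), new Row(1, "b"), new Row(2, "c"), new Row(3, "d"));
        List<Step> steps = split(rows, 3);
        for (int k = 0; k < steps.size(); k++) {
            System.out.println("step " + (k + 1)
                + ": train=" + steps.get(k).training.size()
                + ", test=" + steps.get(k).testing.size());
        }
    }
}
```

With 20 releases, split(rows, 20) would yield the 19 steps mentioned above. Each Step could then be converted to Weka Instances in whatever way the project already does; that conversion is deliberately left out of the sketch.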

I also tried @fracpete's snippet, but I got the same error. The files I used are the following:
training set file
testing set file

Edit 3: this is how I compute the files:

public class FilesCreator {

    private File basicArff;
    private Instances totalData;

    private ArrayList<Instance> testingInstances;
    private File testingSet;
    private File trainingSet;

    /* *******************************************************************/

    public FilesCreator(File csvFile, File arffFile, File training, File testing) 
           throws IOException {
        var loader = new CSVLoader();
        loader.setSource(csvFile);
        this.totalData = loader.getDataSet(); // get instances object
        this.basicArff = arffFile;
        this.testingSet = testing;
        this.trainingSet = training;
    }

    private ArrayList<Attribute> getAttributesList(){
        var attributes = new ArrayList<Attribute>();
        int i;
        for (i = 0; i < this.totalData.numAttributes(); i++)
            attributes.add(this.totalData.attribute(i));
        return attributes;
    }

    private void writeHeader(PrintWriter pf) {
        // just write the attributes in the given file. 
        // f is either this.testingSet or this.trainingSet
        pf.append("@relation " + this.totalData.relationName() + "\n\n");
        pf.flush();
        var attributes = this.getAttributesList();
        for (Attribute line : attributes){
            pf.append(line.toString() + "\n");
            pf.flush();
        }
        pf.append("\n@data\n");
        pf.flush();
    }

    /* *******************************************************************/
    /* testing file */

    // testing instances
    private void computeTestingSet(int indexRelease){
        int i;
        int currIndex;
        // re-initialize the list
        this.testingInstances = new ArrayList<>();
        for (i = 0; i < this.totalData.numInstances(); i++){
            // first attribute is the release index
            currIndex = (int) this.totalData.instance(i).value(0);
            if (currIndex == indexRelease)
                testingInstances.add(this.totalData.instance(i));
            else if (currIndex > indexRelease)
                break;
        }
    }

    // testing file
    private void computeTestingFile(int indexRelease){
        this.computeTestingSet(indexRelease);
        try(var fp = new PrintWriter(this.testingSet)) {

            this.writeHeader(fp);
            for (Instance line : this.testingInstances){
                fp.append(line.toString() + "\n");
                fp.flush();
            }
        } catch (IOException e) {
            var logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }
    /* *******************************************************************/
    // training file
    private void computeTrainingFile(int indexRelease){
        int i;

        try(var fw = new FileWriter(this.trainingSet, true);
            var fp = new PrintWriter(fw)) {
            if (indexRelease == 1) {
                // first iteration: need the header.
                fp.print("");
                fp.flush();
                this.writeHeader(fp);
                for (i = 0; i < this.totalData.numInstances(); i++) {
                    if ( (int) this.totalData.instance(i).value(0) > indexRelease)
                        break;
                    fp.append(this.totalData.instance(i).toString() + "\n");
                    fp.flush();
                }
            }
            else {
                // in this case just append the testing instances, which
                // are the indexReleas+1-th data:
                for (Instance obj : this.testingInstances){
                    fp.append(obj.toString() + "\n");
                    fp.flush();
                }
            }
        } catch (IOException e) {
            var logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }

    /* *******************************************************************/
    // public method
    public void computeFiles(int indexRelease){
        this.computeTrainingFile(indexRelease);
        this.computeTestingFile(indexRelease + 1);
    }
}

The last public method is called in a loop in another class, from 1 to 19:

    FilesCreator filesCreator = new FilesCreator(csvFile, arffFile, training, testing);
    for (i = 1; i < 20; i++) {
        filesCreator.computeFiles(i);
        /* do something with files, such as getting Instances and
           use them for SMOTE computation */
    }

Edit 4: I removed the duplicate instances from totalData in FilesCreator by doing the following:

    var currDir = Paths.get(".").toAbsolutePath().normalize().toFile();
    var ext = ".arff";
    var tmpFile = File.createTempFile("without_replicated", ext, currDir);
    RemoveDuplicates.main(new String[]{"-i", this.basicArff.toPath().toString(), "-o", tmpFile.toPath().toString()});
    // output file has effective 0 instances repetitions 
    var arffLoader = new ArffLoader();
    arffLoader.setSource(tmpFile);
    this.totalData = arffLoader.getDataSet();
    Files.delete(tmpFile.toPath());

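I cannot modify it manually, because it is the output of a previous computation. The code now works for iteration 2, but I get the same error in iteration 3.
The files for this iteration are:
train_iteration4.arff
test_iteration4.arff
This is the complete ARFF file produced by the previous snippet, i.e. the one loaded by arffLoader.setSource(tmpFile);:
full.arff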

I put together a quick example Maven project, and under Java 11 (openjdk version "11.0.11" 2021-04-20) I did not run into any problems.

Example pom.xml:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

  <modelVersion>4.0.0</modelVersion>

  <groupId>com.github.fracpete</groupId>
  <artifactId>smote-test</artifactId>
  <packaging>jar</packaging>
  <version>0.0.1</version>

  <name>smote-test</name>

  <dependencies>
    <dependency>
       <groupId>nz.ac.waikato.cms.weka</groupId>
       <artifactId>weka-stable</artifactId>
       <version>3.8.5</version>
    </dependency>
    <dependency>
       <groupId>nz.ac.waikato.cms.weka</groupId>
       <artifactId>SMOTE</artifactId>
       <version>1.0.2</version>
    </dependency>
  </dependencies>

  <properties>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.target>11</maven.compiler.target>
  </properties>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-jar-plugin</artifactId>
        <version>2.3.2</version>
      </plugin>
      <plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>exec-maven-plugin</artifactId>
        <version>3.0.0</version>
        <executions>
          <execution>
            <goals>
              <goal>exec</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <executable>java</executable>
          <arguments>
            <argument>-classpath</argument>
            <classpath/>
            <argument>smote.SmoteTest</argument>
          </arguments>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>

I adapted your FilesCreator class to use Java 8 syntax (easier for me to read) and to load an ARFF file instead of a CSV file. This class needs to go into src/main/java/smote:

package smote;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

public class FilesCreator {

  private File basicArff;
  private Instances totalData;

  private ArrayList<Instance> testingInstances;
  private File testingSet;
  private File trainingSet;

  /* *******************************************************************/

  public FilesCreator(File fullArffFile, File arffFile, File training, File testing)
      throws IOException {
    ArffLoader loader = new ArffLoader();
    loader.setSource(fullArffFile);
    this.totalData = loader.getDataSet(); // get instances object
    this.basicArff = arffFile;
    this.testingSet = testing;
    this.trainingSet = training;
  }

  private ArrayList<Attribute> getAttributesList(){
    ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    int i;
    for (i = 0; i < this.totalData.numAttributes(); i++)
      attributes.add(this.totalData.attribute(i));
    return attributes;
  }

  private void writeHeader(PrintWriter pf) {
    // just write the attributes in the given file.
    // f is either this.testingSet or this.trainingSet
    pf.append("@relation " + this.totalData.relationName() + "\n\n");
    pf.flush();
    ArrayList<Attribute> attributes = this.getAttributesList();
    for (Attribute line : attributes){
      pf.append(line.toString() + "\n");
      pf.flush();
    }
    pf.append("\n@data\n");
    pf.flush();
  }

  /* *******************************************************************/
  /* testing file */

  // testing instances
  private void computeTestingSet(int indexRelease){
    int i;
    int currIndex;
    // re-initialize the list
    this.testingInstances = new ArrayList<>();
    for (i = 0; i < this.totalData.numInstances(); i++){
      // first attribute is the release index
      currIndex = (int) this.totalData.instance(i).value(0);
      if (currIndex == indexRelease)
        testingInstances.add(this.totalData.instance(i));
      else if (currIndex > indexRelease)
        break;
    }
  }

  // testing file
  private void computeTestingFile(int indexRelease){
    this.computeTestingSet(indexRelease);
    try(PrintWriter fp = new PrintWriter(this.testingSet)) {

      this.writeHeader(fp);
      for (Instance line : this.testingInstances){
        fp.append(line.toString() + "\n");
        fp.flush();
      }
    } catch (IOException e) {
      Logger logger = Logger.getLogger(FilesCreator.class.getName());
      logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
    }
  }
  /* *******************************************************************/
  // training file
  private void computeTrainingFile(int indexRelease){
    int i;

    try(FileWriter fw = new FileWriter(this.trainingSet, true);
        PrintWriter fp = new PrintWriter(fw)) {
      if (indexRelease == 1) {
        // first iteration: need the header.
        fp.print("");
        fp.flush();
        this.writeHeader(fp);
        for (i = 0; i < this.totalData.numInstances(); i++) {
          if ( (int) this.totalData.instance(i).value(0) > indexRelease)
            break;
          fp.append(this.totalData.instance(i).toString() + "\n");
          fp.flush();
        }
      }
      else {
        // in this case just append the testing instances, which
        // are the indexReleas+1-th data:
        for (Instance obj : this.testingInstances){
          fp.append(obj.toString() + "\n");
          fp.flush();
        }
      }
    } catch (IOException e) {
      Logger logger = Logger.getLogger(FilesCreator.class.getName());
      logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
    }
  }

  /* *******************************************************************/
  // public method
  public void computeFiles(int indexRelease){
    this.computeTrainingFile(indexRelease);
    this.computeTestingFile(indexRelease + 1);
  }
}

I then combined your two example ARFF files into a single file: full.arff

The smote.SmoteTest class lives in src/main/java/smote (you will need to change the hard-coded paths to the dataset):

package smote;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.supervised.instance.SMOTE;

import java.io.File;

public class SmoteTest {

  public static void main(String[] args) throws Exception {
    FilesCreator filesCreator = new FilesCreator(new File("/home/fracpete/full.arff"), new File("."), new File("/home/fracpete/train.arff"), new File("/home/fracpete/test.arff"));
    for (int i = 1; i < 20; i++) {
      System.out.println("--> " + i);
      filesCreator.computeFiles(i);
      System.out.println("Reading train");
      Instances trainingDataset = DataSource.read("/home/fracpete/train.arff");
      trainingDataset.setClassIndex(trainingDataset.numAttributes() - 1);
      System.out.println("Reading test");
      Instances testingDataset = DataSource.read("/home/fracpete/test.arff");
      testingDataset.setClassIndex(testingDataset.numAttributes() - 1);
      System.out.println("smote");
      FilteredClassifier filtered = new FilteredClassifier();
      SMOTE smote = new SMOTE();
      filtered.setFilter(smote);
      filtered.setClassifier(new NaiveBayes());
      filtered.buildClassifier(trainingDataset);
      Evaluation currEvaluation = new Evaluation(testingDataset);
      currEvaluation.evaluateModel(filtered, testingDataset);
      System.out.println(currEvaluation.toSummaryString());
    }
  }
}

Then compile and run it:

mvn clean install
mvn exec:exec 

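This fails in the second iteration for me as well.

However, after removing all duplicate rows from the combined dataset, the error disappeared.

My feeling is that SMOTE cannot handle duplicate rows in the input data.

Note: before any run, make sure to delete the train.arff and test.arff files.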
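
For readers who want to try the same experiment without going through Weka's RemoveDuplicates (used in Edit 4 of the question), here is a minimal plain-Java sketch that drops repeated data rows from ARFF text. It assumes one instance per line and treats two rows as duplicates only when their text is identical after trimming; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper, not part of Weka: drops repeated data rows from ARFF text.
class ArffDedupSketch {

    /**
     * Copies the header unchanged and keeps only the first occurrence of each
     * data row. Assumes one instance per line and compares rows textually.
     */
    static List<String> removeDuplicateRows(List<String> arffLines) {
        List<String> result = new ArrayList<>();
        Set<String> seen = new HashSet<>();
        boolean inData = false;
        for (String line : arffLines) {
            String trimmed = line.trim();
            if (!inData) {
                // Header section: copy everything up to and including @data.
                result.add(line);
                if (trimmed.equalsIgnoreCase("@data")) {
                    inData = true;
                }
            } else if (trimmed.isEmpty() || trimmed.startsWith("%")) {
                // Keep blank lines and ARFF comments as they are.
                result.add(line);
            } else if (seen.add(trimmed)) {
                // Set.add returns true only the first time a row is seen.
                result.add(line);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> lines = List.of(
            "@relation demo", "@attribute a numeric", "@data", "1", "2", "1", "3");
        removeDuplicateRows(lines).forEach(System.out::println);
    }
}
```

A textual comparison will not catch rows that are numerically equal but written differently (for example 1 and 1.0), so this is a quick diagnostic rather than a replacement for a proper filter.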
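
The reason for this cleanup is that FilesCreator opens the training file with new FileWriter(this.trainingSet, true), i.e. in append mode, so leftovers from an earlier run would end up in the new training set. A small sketch of doing the cleanup from code follows; the class name and the relative paths are placeholders, not part of the project above.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: removes stale output files before a new run.
class CleanOutputsSketch {

    /** Deletes each given file if it exists; missing files are silently skipped. */
    static void deleteStale(Path... files) throws IOException {
        for (Path file : files) {
            Files.deleteIfExists(file);
        }
    }

    public static void main(String[] args) throws IOException {
        // Placeholder paths; point them at your own train/test files.
        deleteStale(Path.of("train.arff"), Path.of("test.arff"));
    }
}
```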

I solved the problem by changing the SMOTE dependency in my pom.xml:

<dependency>
   <groupId>nz.ac.waikato.cms.weka</groupId>
   <artifactId>SMOTE</artifactId>
   <version>1.0.3</version>
</dependency>

With this version I have no problems and my code runs as expected. I hope this helps others.
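
As background on the exception message itself: Collections.sort may throw "Comparison method violates its general contract!" when the Comparator it is given is inconsistent, for example when compare(a, b) and compare(b, a) do not have opposite signs. The answers here do not say what changed in SMOTE 1.0.3, and the sketch below does not claim to reproduce SMOTE's comparator. It only illustrates, with made-up names, one common way such a violation arises (rounding a double difference to an int) next to a comparator that respects the contract.

```java
import java.util.Comparator;

// Illustration only; not taken from Weka or the SMOTE package.
class ComparatorContractSketch {

    // Rounding a double difference to int is asymmetric:
    // ceil(0.5) is 1, but ceil(-0.5) is -0.0, which casts to 0.
    static final Comparator<Double> ROUNDING = (a, b) -> (int) Math.ceil(a - b);

    // Double.compare always returns opposite signs for swapped arguments.
    static final Comparator<Double> SAFE = Double::compare;

    /** True when compare(a, b) and compare(b, a) have opposite signs (or are both 0). */
    static boolean isAntisymmetric(Comparator<Double> c, double a, double b) {
        return Integer.signum(c.compare(a, b)) == -Integer.signum(c.compare(b, a));
    }

    public static void main(String[] args) {
        System.out.println("rounding comparator consistent: "
            + isAntisymmetric(ROUNDING, 0.5, 0.0));
        System.out.println("Double.compare consistent:      "
            + isAntisymmetric(SAFE, 0.5, 0.0));
    }
}
```

A check like isAntisymmetric can be run over pairs of real values when debugging a custom comparator, since the sort itself only reports the problem for some inputs.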