Using SMOTE on Java raises "Comparison method violates its general contract"
I'm working on a Java project in which I need to use Weka's API. I use Maven to manage the dependencies and, in particular, I have the following one:
<dependency>
    <groupId>nz.ac.waikato.cms.weka</groupId>
    <artifactId>weka-stable</artifactId>
    <version>3.8.5</version>
</dependency>
In this version the SMOTE class is not included, but I really need it; that's why I also added the following dependency to my pom.xml:
<dependency>
    <groupId>nz.ac.waikato.cms.weka</groupId>
    <artifactId>SMOTE</artifactId>
    <version>1.0.2</version>
</dependency>
In my Java code I'm also trying to implement the Walk Forward validation technique: I can prepare the training set and the testing set of every step in advance, so that I can use them in a loop. What I do is the following:
for (...){
    var filtered = new FilteredClassifier();
    var smote = new SMOTE();
    filtered.setFilter(smote);
    filtered.setClassifier(new NaiveBayes());
    filtered.buildClassifier(trainingDataset);
    var currEvaluation = new Evaluation(testingDataset);
    currEvaluation.evaluateModel(filtered, testingDataset);
}
trainingDataset and testingDataset are of type Instances and their values change appropriately at every iteration. In the first iteration no problem occurs, but in the second one I get java.lang.IllegalArgumentException: Comparison method violates its general contract!. The exception stack trace is:
java.lang.IllegalArgumentException: Comparison method violates its general contract!
    at java.base/java.util.TimSort.mergeLo(TimSort.java:781)
    at java.base/java.util.TimSort.mergeAt(TimSort.java:518)
    at java.base/java.util.TimSort.mergeCollapse(TimSort.java:448)
    at java.base/java.util.TimSort.sort(TimSort.java:245)
    at java.base/java.util.Arrays.sort(Arrays.java:1441)
    at java.base/java.util.List.sort(List.java:506)
    at java.base/java.util.Collections.sort(Collections.java:179)
    at weka.filters.supervised.instance.SMOTE.doSMOTE(SMOTE.java:637)
    at weka.filters.supervised.instance.SMOTE.batchFinished(SMOTE.java:489)
    at weka.filters.Filter.useFilter(Filter.java:708)
    at weka.classifiers.meta.FilteredClassifier.setUp(FilteredClassifier.java:719)
    at weka.classifiers.meta.FilteredClassifier.buildClassifier(FilteredClassifier.java:794)
Does anyone know how to solve this issue?
Thanks in advance.
EDIT: I forgot to say that I'm using Java 11.0.11.
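For context: since Java 7, Collections.sort delegates to TimSort, which actively checks that the supplied Comparator is anti-symmetric and transitive, and throws exactly this IllegalArgumentException when it detects a violation. Below is a minimal sketch, unrelated to Weka's internals, that typically reproduces the same error with a deliberately inconsistent comparator:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class BrokenComparatorDemo {
    public static void main(String[] args) {
        Random rnd = new Random(42);
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < 10_000; i++)
            list.add(rnd.nextInt());
        // A comparator that answers at random is neither anti-symmetric
        // nor transitive; on a large enough list TimSort usually notices
        // the inconsistency and throws:
        // "Comparison method violates its general contract!"
        Collections.sort(list, (a, b) -> rnd.nextInt(3) - 1);
    }
}

In Weka's case the comparator lives inside SMOTE.doSMOTE (line 637 of the stack trace), so the trigger has to be something about the data it sorts.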
EDIT 2: from @fracpete's answer, I infer that the problem may be in the creation of the sets. I should point out that I'm trying to predict the bugginess of the classes of another open source project. Because of Walk Forward, I have 19 steps, so there should be 19 different training files and 19 testing files. To avoid that, I keep a list of InfoKeeper objects, a class that holds the training and testing Instances of each step. While building this list I do the following:
- I create 2 temporary files from the base ARFF file: the training set file holds version 1 data, the testing set file holds version 2 data. Then I read these temporary ARFFs to create the Instances objects, which are held by the InfoKeeper of step 1.
- I append the rows of the testing set file (only the data, of course) to the training set file, so that it holds version 1 and version 2 data. Then I overwrite the testing set file so that it holds version 3 data. I read these temporary ARFFs to get the Instances held by the InfoKeeper of step 2.
The code iterates as in step 2 to create all the remaining InfoKeepers. Could this operation be the problem?
I also tried @fracpete's snippet, but the same error came up. The files I'm using are the following:
training set file
testing set file
EDIT 3: this is how I compute the files:
public class FilesCreator {

    private File basicArff;
    private Instances totalData;
    private ArrayList<Instance> testingInstances;
    private File testingSet;
    private File trainingSet;

    /* *******************************************************************/

    public FilesCreator(File csvFile, File arffFile, File training, File testing)
            throws IOException {
        var loader = new CSVLoader();
        loader.setSource(csvFile);
        this.totalData = loader.getDataSet(); // get instances object
        this.basicArff = arffFile;
        this.testingSet = testing;
        this.trainingSet = training;
    }

    private ArrayList<Attribute> getAttributesList(){
        var attributes = new ArrayList<Attribute>();
        int i;
        for (i = 0; i < this.totalData.numAttributes(); i++)
            attributes.add(this.totalData.attribute(i));
        return attributes;
    }

    private void writeHeader(PrintWriter pf) {
        // just write the attributes in the given file.
        // f is either this.testingSet or this.trainingSet
        pf.append("@relation " + this.totalData.relationName() + "\n\n");
        pf.flush();
        var attributes = this.getAttributesList();
        for (Attribute line : attributes){
            pf.append(line.toString() + "\n");
            pf.flush();
        }
        pf.append("\n@data\n");
        pf.flush();
    }

    /* *******************************************************************/
    /* testing file */

    // testing instances
    private void computeTestingSet(int indexRelease){
        int i;
        int currIndex;
        // re-initialize the list
        this.testingInstances = new ArrayList<>();
        for (i = 0; i < this.totalData.numInstances(); i++){
            // first attribute is the release index
            currIndex = (int) this.totalData.instance(i).value(0);
            if (currIndex == indexRelease)
                testingInstances.add(this.totalData.instance(i));
            else if (currIndex > indexRelease)
                break;
        }
    }

    // testing file
    private void computeTestingFile(int indexRelease){
        this.computeTestingSet(indexRelease);
        try(var fp = new PrintWriter(this.testingSet)) {
            this.writeHeader(fp);
            for (Instance line : this.testingInstances){
                fp.append(line.toString() + "\n");
                fp.flush();
            }
        } catch (IOException e) {
            var logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }

    /* *******************************************************************/

    // training file
    private void computeTrainingFile(int indexRelease){
        int i;
        try(var fw = new FileWriter(this.trainingSet, true);
            var fp = new PrintWriter(fw)) {
            if (indexRelease == 1) {
                // first iteration: need the header.
                fp.print("");
                fp.flush();
                this.writeHeader(fp);
                for (i = 0; i < this.totalData.numInstances(); i++) {
                    if ( (int) this.totalData.instance(i).value(0) > indexRelease)
                        break;
                    fp.append(this.totalData.instance(i).toString() + "\n");
                    fp.flush();
                }
            }
            else {
                // in this case just append the testing instances, which
                // are the indexRelease+1-th data:
                for (Instance obj : this.testingInstances){
                    fp.append(obj.toString() + "\n");
                    fp.flush();
                }
            }
        } catch (IOException e) {
            var logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }

    /* *******************************************************************/

    // public method
    public void computeFiles(int indexRelease){
        this.computeTrainingFile(indexRelease);
        this.computeTestingFile(indexRelease + 1);
    }
}
The last public method is called inside a loop of another class, from 1 to 19:
FilesCreator filesCreator = new FilesCreator(csvFile, arffFile, training, testing);
for (i = 1; i < 20; i++) {
    filesCreator.computeFiles(i);
    /* do something with files, such as getting Instances and
       using them for the SMOTE computation */
}
EDIT 4: I removed the duplicated instances from totalData in FilesCreator by doing the following:
var currDir = Paths.get(".").toAbsolutePath().normalize().toFile();
var ext = ".arff";
var tmpFile = File.createTempFile("without_replicated", ext, currDir);
RemoveDuplicates.main(new String[]{"-i", this.basicArff.toPath().toString(), "-o", tmpFile.toPath().toString()});
// output file has effective 0 instances repetitions
var arffLoader = new ArffLoader();
arffLoader.setSource(tmpFile);
this.totalData = arffLoader.getDataSet();
Files.delete(tmpFile.toPath());
I cannot modify it by hand, because it's the output of a previous computation. The code now works for iteration 2, but I get the same error at iteration 3. The files of this iteration are:
train_iteration4.arff
test_iteration4.arff
And this is the complete ARFF file produced by the previous snippet, the one loaded by arffLoader.setSource(tmpFile);:
full.arff
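As an aside, the de-duplication from EDIT 4 can also be done in memory with Weka's filter API instead of round-tripping through RemoveDuplicates.main and a temporary file. A minimal sketch, assuming the usual setInputFormat/useFilter pattern (both calls are declared to throw Exception):

import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemoveDuplicates;

// inside FilesCreator, in a context that can handle Exception:
RemoveDuplicates dedup = new RemoveDuplicates();
dedup.setInputFormat(this.totalData);                     // declare the input structure
this.totalData = Filter.useFilter(this.totalData, dedup); // copy without duplicate rows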
I put together a quick example Maven project and did not run into any problems with it under Java 11 (openjdk version "11.0.11" 2021-04-20):
Example pom.xml:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.github.fracpete</groupId>
  <artifactId>smote-test</artifactId>
  <packaging>jar</packaging>
  <version>0.0.1</version>
  <name>smote-test</name>
  <dependencies>
    <dependency>
      <groupId>nz.ac.waikato.cms.weka</groupId>
      <artifactId>weka-stable</artifactId>
      <version>3.8.5</version>
    </dependency>
    <dependency>
      <groupId>nz.ac.waikato.cms.weka</groupId>
      <artifactId>SMOTE</artifactId>
      <version>1.0.2</version>
    </dependency>
  </dependencies>
  <properties>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.target>11</maven.compiler.target>
  </properties>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-jar-plugin</artifactId>
        <version>2.3.2</version>
      </plugin>
      <plugin>
        <groupId>org.codehaus.mojo</groupId>
        <artifactId>exec-maven-plugin</artifactId>
        <version>3.0.0</version>
        <executions>
          <execution>
            <goals>
              <goal>exec</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <executable>java</executable>
          <arguments>
            <argument>-classpath</argument>
            <classpath/>
            <argument>smote.SmoteTest</argument>
          </arguments>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>
I've adapted your FilesCreator class to use Java 8 syntax (easier on my eyes) and to load an ARFF file instead of a CSV file. This class needs to go into src/main/java/smote:
package smote;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

public class FilesCreator {

    private File basicArff;
    private Instances totalData;
    private ArrayList<Instance> testingInstances;
    private File testingSet;
    private File trainingSet;

    /* *******************************************************************/

    public FilesCreator(File fullArffFile, File arffFile, File training, File testing)
            throws IOException {
        ArffLoader loader = new ArffLoader();
        loader.setSource(fullArffFile);
        this.totalData = loader.getDataSet(); // get instances object
        this.basicArff = arffFile;
        this.testingSet = testing;
        this.trainingSet = training;
    }

    private ArrayList<Attribute> getAttributesList(){
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        int i;
        for (i = 0; i < this.totalData.numAttributes(); i++)
            attributes.add(this.totalData.attribute(i));
        return attributes;
    }

    private void writeHeader(PrintWriter pf) {
        // just write the attributes in the given file.
        // f is either this.testingSet or this.trainingSet
        pf.append("@relation " + this.totalData.relationName() + "\n\n");
        pf.flush();
        ArrayList<Attribute> attributes = this.getAttributesList();
        for (Attribute line : attributes){
            pf.append(line.toString() + "\n");
            pf.flush();
        }
        pf.append("\n@data\n");
        pf.flush();
    }

    /* *******************************************************************/
    /* testing file */

    // testing instances
    private void computeTestingSet(int indexRelease){
        int i;
        int currIndex;
        // re-initialize the list
        this.testingInstances = new ArrayList<>();
        for (i = 0; i < this.totalData.numInstances(); i++){
            // first attribute is the release index
            currIndex = (int) this.totalData.instance(i).value(0);
            if (currIndex == indexRelease)
                testingInstances.add(this.totalData.instance(i));
            else if (currIndex > indexRelease)
                break;
        }
    }

    // testing file
    private void computeTestingFile(int indexRelease){
        this.computeTestingSet(indexRelease);
        try(PrintWriter fp = new PrintWriter(this.testingSet)) {
            this.writeHeader(fp);
            for (Instance line : this.testingInstances){
                fp.append(line.toString() + "\n");
                fp.flush();
            }
        } catch (IOException e) {
            Logger logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }

    /* *******************************************************************/

    // training file
    private void computeTrainingFile(int indexRelease){
        int i;
        try(FileWriter fw = new FileWriter(this.trainingSet, true);
            PrintWriter fp = new PrintWriter(fw)) {
            if (indexRelease == 1) {
                // first iteration: need the header.
                fp.print("");
                fp.flush();
                this.writeHeader(fp);
                for (i = 0; i < this.totalData.numInstances(); i++) {
                    if ( (int) this.totalData.instance(i).value(0) > indexRelease)
                        break;
                    fp.append(this.totalData.instance(i).toString() + "\n");
                    fp.flush();
                }
            }
            else {
                // in this case just append the testing instances, which
                // are the indexRelease+1-th data:
                for (Instance obj : this.testingInstances){
                    fp.append(obj.toString() + "\n");
                    fp.flush();
                }
            }
        } catch (IOException e) {
            Logger logger = Logger.getLogger(FilesCreator.class.getName());
            logger.log(Level.OFF, Arrays.toString(e.getStackTrace()));
        }
    }

    /* *******************************************************************/

    // public method
    public void computeFiles(int indexRelease){
        this.computeTrainingFile(indexRelease);
        this.computeTestingFile(indexRelease + 1);
    }
}
I then combined your two example ARFF files into a single file: full.arff
The smote.SmoteTest class goes into src/main/java/smote as well (you will need to change the hard-coded paths of your datasets):
package smote;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.supervised.instance.SMOTE;

import java.io.File;

public class SmoteTest {

    public static void main(String[] args) throws Exception {
        FilesCreator filesCreator = new FilesCreator(new File("/home/fracpete/full.arff"), new File("."), new File("/home/fracpete/train.arff"), new File("/home/fracpete/test.arff"));
        for (int i = 1; i < 20; i++) {
            System.out.println("--> " + i);
            filesCreator.computeFiles(i);

            System.out.println("Reading train");
            Instances trainingDataset = DataSource.read("/home/fracpete/train.arff");
            trainingDataset.setClassIndex(trainingDataset.numAttributes() - 1);

            System.out.println("Reading test");
            Instances testingDataset = DataSource.read("/home/fracpete/test.arff");
            testingDataset.setClassIndex(trainingDataset.numAttributes() - 1);

            System.out.println("smote");
            FilteredClassifier filtered = new FilteredClassifier();
            SMOTE smote = new SMOTE();
            filtered.setFilter(smote);
            filtered.setClassifier(new NaiveBayes());
            filtered.buildClassifier(trainingDataset);
            Evaluation currEvaluation = new Evaluation(testingDataset);
            currEvaluation.evaluateModel(filtered, testingDataset);
            System.out.println(currEvaluation.toSummaryString());
        }
    }
}
Then compile and run it:
mvn clean install
mvn exec:exec
And it fails in the second iteration as well. However, after removing all duplicate rows from the combined dataset, the error went away.
My gut feeling is that SMOTE cannot handle duplicate rows in the input data.
Note: before any run, make sure you delete the train.arff and test.arff files.
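One way to check that suspicion in your own runs is to count repeated rows before building the classifier. A small sketch, keyed on each instance's string form (an approximation that is good enough for ARFF rows):

import java.util.HashSet;
import java.util.Set;

// count rows whose string representation has already been seen;
// a non-zero count means the dataset still contains duplicates
Set<String> seen = new HashSet<>();
int duplicates = 0;
for (int i = 0; i < trainingDataset.numInstances(); i++) {
    if (!seen.add(trainingDataset.instance(i).toString()))
        duplicates++;
}
System.out.println("duplicate rows: " + duplicates);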
I solved the problem by changing the SMOTE dependency in my pom.xml:
<dependency>
    <groupId>nz.ac.waikato.cms.weka</groupId>
    <artifactId>SMOTE</artifactId>
    <version>1.0.3</version>
</dependency>
With this version I don't have any problem and my code works as expected. Hope this will help someone else.
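If the error ever resurfaces, it is worth verifying which SMOTE jar actually ended up on the classpath (an older version can sneak in transitively). A quick sketch that prints the jar the class was loaded from:

// the printed location should contain SMOTE-1.0.3.jar if the
// dependency change took effect
System.out.println(weka.filters.supervised.instance.SMOTE.class
        .getProtectionDomain().getCodeSource().getLocation());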