关于生成决策树的 StackOverflowError JAVA
StackOverflowError on decision tree generating JAVA
我正在尝试编写生成决策树的 ID3 算法,但是当我 运行 我的代码时出现 WhosebugError。
调试时我注意到当属性下降到 4(从最初的 9)时循环开始。
生成树的代码如下。我调用的所有函数都正常工作,它们已经过测试。
但是,错误代码指出问题出在另一个使用流的函数中,但它已被单独测试
我知道它工作正常。请记住,我正在处理随机数据,因此该函数有时会抛出
错误,有时不会。我post它下面的错误代码,但是熵函数和信息增益工作。
这是树节点结构:
public class TreeNode {
List<Patient> samples;
List<TreeNode> children;
TreeNode parent;
Integer attribute;
String attributeValue;
String className;
public TreeNode(List<Patient> samples, List<TreeNode> children, TreeNode parent, Integer attribute,
String attributeValue, String className) {
this.samples = samples;
this.children = children;
this.parent = parent;
this.attribute = attribute;
this.attributeValue = attributeValue;
this.className = className;
}
}
这就是引发错误的代码:
public TreeNode id3(List<Patient> patients, List<Integer> attributes, TreeNode root) {
boolean isLeaf = patients.stream().collect(Collectors.groupingBy(i -> i.className)).keySet().size() == 1;
if (isLeaf) {
root.setClassName(patients.get(0).className);
return root;
}
if (attributes.size() == 0) {
root.setClassName(mostCommonClass(patients));
return root;
}
int bestAttribute = maxInformationGainAttribute(patients, attributes);
Set<String> attributeValues = attributeValues(patients, bestAttribute);
for (String value : attributeValues) {
List<Patient> branch = patients.stream().filter(i -> i.patientData[bestAttribute].equals(value))
.collect(Collectors.toList());
TreeNode child = new TreeNode(branch, new ArrayList<>(), root, bestAttribute, value, null);
if (branch.isEmpty()) {
child.setClassName(mostCommonClass(patients));
root.addChild(new TreeNode(child));
} else {
List<Integer> newAttributes = new ArrayList<>();
newAttributes.addAll(attributes);
newAttributes.remove(new Integer(bestAttribute));
root.addChild(new TreeNode(id3(branch, newAttributes, child)));
}
}
return root;
}
这些是其他功能:
public static double entropy(List<Patient> patients) {
double entropy = 0.0;
double recurP = (double) patients.stream().filter(i -> i.className.equals("recurrence-events")).count()
/ (double) patients.size();
double noRecurP = (double) patients.stream().filter(i -> i.className.equals("no-recurrence-events")).count()
/ (double) patients.size();
entropy -= (recurP * (recurP > 0 ? Math.log(recurP) : 0 / Math.log(2))
+ noRecurP * (noRecurP > 0 ? Math.log(noRecurP) : 0 / Math.log(2)));
return entropy;
}
public static double informationGain(List<Patient> patients, int attribute) {
double informationGain = entropy(patients);
Map<String, List<Patient>> patientsGroupedByAttribute = patients.stream()
.collect(Collectors.groupingBy(i -> i.patientData[attribute]));
List<List<Patient>> subsets = new ArrayList<>();
for (String i : patientsGroupedByAttribute.keySet()) {
subsets.add(patientsGroupedByAttribute.get(i));
}
for (List<Patient> lp : subsets) {
informationGain -= proportion(lp, patients) * entropy(lp);
}
return informationGain;
}
private static int maxInformationGainAttribute(List<Patient> patients, List<Integer> attributes) {
int maxAttribute = 0;
double maxInformationGain = 0;
for (int i : attributes) {
if (informationGain(patients, i) > maxInformationGain) {
maxAttribute = i;
maxInformationGain = informationGain(patients, i);
}
}
return maxAttribute;
}
例外情况:
Exception in thread "main" java.lang.WhosebugError
at java.util.stream.ReferencePipeline.accept(Unknown Source)
at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(Unknown Source)
at java.util.stream.AbstractPipeline.copyInto(Unknown Source)
at java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source)
at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(Unknown Source)
at java.util.stream.AbstractPipeline.evaluate(Unknown Source)
at java.util.stream.LongPipeline.reduce(Unknown Source)
at java.util.stream.LongPipeline.sum(Unknown Source)
at java.util.stream.ReferencePipeline.count(Unknown Source)
at Patient.entropy(Patient.java:39)
at Patient.informationGain(Patient.java:67)
at Patient.maxInformationGainAttribute(Patient.java:85)
at Patient.id3(Patient.java:109)
行:
root.addChild(new TreeNode(id3(branch, newAttributes, child)));
每次方法递归都在调用,导致栈溢出。这告诉我您的逻辑有问题,其中 "base cases" 的 none 结束了递归,即 return root,正在被访问。我对所需的行为或起始数据了解不多,无法查明出了什么问题,但我会首先使用调试器单步执行代码,并确保方法中的逻辑符合您的预期。我知道这不是一个很好的答案,但它是一个起点,希望能有所帮助,或者其他人会提出更具体的解决方案。
我正在尝试编写生成决策树的 ID3 算法,但是当我 运行 我的代码时出现 WhosebugError。 调试时我注意到当属性下降到 4(从最初的 9)时循环开始。 生成树的代码如下。我调用的所有函数都正常工作,它们已经过测试。 但是,错误代码指出问题出在另一个使用流的函数中,但它已被单独测试 我知道它工作正常。请记住,我正在处理随机数据,因此该函数有时会抛出 错误,有时不会。我post它下面的错误代码,但是熵函数和信息增益工作。
这是树节点结构:
public class TreeNode {
List<Patient> samples;
List<TreeNode> children;
TreeNode parent;
Integer attribute;
String attributeValue;
String className;
public TreeNode(List<Patient> samples, List<TreeNode> children, TreeNode parent, Integer attribute,
String attributeValue, String className) {
this.samples = samples;
this.children = children;
this.parent = parent;
this.attribute = attribute;
this.attributeValue = attributeValue;
this.className = className;
}
}
这就是引发错误的代码:
public TreeNode id3(List<Patient> patients, List<Integer> attributes, TreeNode root) {
boolean isLeaf = patients.stream().collect(Collectors.groupingBy(i -> i.className)).keySet().size() == 1;
if (isLeaf) {
root.setClassName(patients.get(0).className);
return root;
}
if (attributes.size() == 0) {
root.setClassName(mostCommonClass(patients));
return root;
}
int bestAttribute = maxInformationGainAttribute(patients, attributes);
Set<String> attributeValues = attributeValues(patients, bestAttribute);
for (String value : attributeValues) {
List<Patient> branch = patients.stream().filter(i -> i.patientData[bestAttribute].equals(value))
.collect(Collectors.toList());
TreeNode child = new TreeNode(branch, new ArrayList<>(), root, bestAttribute, value, null);
if (branch.isEmpty()) {
child.setClassName(mostCommonClass(patients));
root.addChild(new TreeNode(child));
} else {
List<Integer> newAttributes = new ArrayList<>();
newAttributes.addAll(attributes);
newAttributes.remove(new Integer(bestAttribute));
root.addChild(new TreeNode(id3(branch, newAttributes, child)));
}
}
return root;
}
这些是其他功能:
public static double entropy(List<Patient> patients) {
double entropy = 0.0;
double recurP = (double) patients.stream().filter(i -> i.className.equals("recurrence-events")).count()
/ (double) patients.size();
double noRecurP = (double) patients.stream().filter(i -> i.className.equals("no-recurrence-events")).count()
/ (double) patients.size();
entropy -= (recurP * (recurP > 0 ? Math.log(recurP) : 0 / Math.log(2))
+ noRecurP * (noRecurP > 0 ? Math.log(noRecurP) : 0 / Math.log(2)));
return entropy;
}
public static double informationGain(List<Patient> patients, int attribute) {
double informationGain = entropy(patients);
Map<String, List<Patient>> patientsGroupedByAttribute = patients.stream()
.collect(Collectors.groupingBy(i -> i.patientData[attribute]));
List<List<Patient>> subsets = new ArrayList<>();
for (String i : patientsGroupedByAttribute.keySet()) {
subsets.add(patientsGroupedByAttribute.get(i));
}
for (List<Patient> lp : subsets) {
informationGain -= proportion(lp, patients) * entropy(lp);
}
return informationGain;
}
private static int maxInformationGainAttribute(List<Patient> patients, List<Integer> attributes) {
int maxAttribute = 0;
double maxInformationGain = 0;
for (int i : attributes) {
if (informationGain(patients, i) > maxInformationGain) {
maxAttribute = i;
maxInformationGain = informationGain(patients, i);
}
}
return maxAttribute;
}
例外情况:
Exception in thread "main" java.lang.WhosebugError
at java.util.stream.ReferencePipeline.accept(Unknown Source)
at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(Unknown Source)
at java.util.stream.AbstractPipeline.copyInto(Unknown Source)
at java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source)
at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(Unknown Source)
at java.util.stream.AbstractPipeline.evaluate(Unknown Source)
at java.util.stream.LongPipeline.reduce(Unknown Source)
at java.util.stream.LongPipeline.sum(Unknown Source)
at java.util.stream.ReferencePipeline.count(Unknown Source)
at Patient.entropy(Patient.java:39)
at Patient.informationGain(Patient.java:67)
at Patient.maxInformationGainAttribute(Patient.java:85)
at Patient.id3(Patient.java:109)
行:
root.addChild(new TreeNode(id3(branch, newAttributes, child)));
每次方法递归都在调用,导致栈溢出。这告诉我您的逻辑有问题,其中 "base cases" 的 none 结束了递归,即 return root,正在被访问。我对所需的行为或起始数据了解不多,无法查明出了什么问题,但我会首先使用调试器单步执行代码,并确保方法中的逻辑符合您的预期。我知道这不是一个很好的答案,但它是一个起点,希望能有所帮助,或者其他人会提出更具体的解决方案。