发布 kNN 算法的加载数据集 - Java
Issue loading data sets for kNN algorithm - Java
我已将 KNN 算法应用于 class 化手写数字。这些数字最初是 8*8 的矢量格式,并被拉伸形成一个 1*64 的矢量,每组数据的 class 代码为 0..9。
据我所知,我的代码在理论上应该可以工作,但这是我第一次试验这种算法。当我尝试通过我的算法输入我的数据集时,我的问题源于我在代码中突出显示的行上抛出了一个错误。可以找到训练数据集here and the validation set here。如果有帮助,我也留在了我以前的工作主要功能中。
ImageMatrix.java
import java.util.*;
public class ImageMatrix {
private int[] data;
private int classCode;
public ImageMatrix(int[] data, int classCode) {
assert data.length == 64; //maximum array length of 64
this.data = data;
this.classCode = classCode;
}
public String toString() {
return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable
}
public int[] getData() {
return data;
}
public int getClassCode() {
return classCode;
}
}
ImageMatrixDB.java
import java.util.*;
import java.io.*;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
private List<ImageMatrix> list = new ArrayList<ImageMatrix>();
public ImageMatrixDB load(String f) throws IOException {
try (
FileReader fr = new FileReader(f);
BufferedReader br = new BufferedReader(fr)) {
String line = null;
while((line = br.readLine()) != null) {
int lastComma = line.lastIndexOf(',');
int classCode = Integer.parseInt(line.substring(1 + lastComma));
int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
.mapToInt(Integer::parseInt)
.toArray();
ImageMatrix matrix = new ImageMatrix(data, classCode);
list.add(matrix);
}
}
return this;
}
public void printResults(){ //output results
for(ImageMatrix matrix: list){
System.out.println(matrix);
}
}
public Iterator<ImageMatrix> iterator() {
return this.list.iterator();
}
/// kNN implementation ///
public static int distance(int[] a, int[] b) {
int sum = 0;
for(int i = 0; i < a.length; i++) {
sum += (a[i] - b[i]) * (a[i] - b[i]);
}
return (int)Math.sqrt(sum); //Euclidean sqrt of the sum
}
public static int classify(List<ImageMatrix> trainingSet, int[] curData) {
int label = 0, bestDistance = Integer.MAX_VALUE;
for(ImageMatrix matrix: trainingSet) {
int dist = distance(matrix.getData(), curData);
if(dist < bestDistance) {
bestDistance = dist;
curData = matrix.getData();
}
}
return label;
}
public static void main(String[] argv) throws IOException {
ImageMatrixDB i = new ImageMatrixDB();
List<ImageMatrix> trainingSet = i.load("cw2DataSet1.csv"); // << ERROR HERE
List<ImageMatrix> validationSet = i.load("cw2DataSet2.csv"); //<< ERROR HERE
int numCorrect = 0;
for(ImageMatrix matrix:validationSet) {
if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
}
System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");
}
//////////////////////////////////////////
// Previous working dataset Load //
/* public static void main(String[] args){
ImageMatrixDB i = new ImageMatrixDB();
try{
i.load("cw2DataSet1.csv");
i.printResults();
}
catch(Exception ex){
ex.printStackTrace();
}
} */
}
编辑///
目前的错误信息是:
Exception in thread "main" java.lang.Error: Unresolved compilation problems:
Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
at ImageMatrixDB.main(ImageMatrixDB.java:64)
但是我在测试的时候还报了其他错误。
你设计你的class的方式,应该按如下方式使用:
ImageMatrixDB trainingSet = new ImageMatrixDB();
ImageMatrixDB validationSet = new ImageMatrixDB();
trainingSet.load("cw2DataSet1.csv");
validationSet.load("cw2DataSet2.csv");
注意 ImageMatrixDB 的两个实例而不是一个,这确保了训练/验证数据被加载到不同的列表中。
附带说明一下,在计算 kNN 中的距离时,您应该能够使用平方距离(效率增益,sqrt 是一项昂贵的操作)。所以 return (int)Math.sqrt(sum);
不需要平方根。
我已将 KNN 算法应用于 class 化手写数字。这些数字最初是 8*8 的矢量格式,并被拉伸形成一个 1*64 的矢量,每组数据的 class 代码为 0..9。
据我所知,我的代码在理论上应该可以工作,但这是我第一次试验这种算法。当我尝试通过我的算法输入我的数据集时,我的问题源于我在代码中突出显示的行上抛出了一个错误。可以找到训练数据集here and the validation set here。如果有帮助,我也留在了我以前的工作主要功能中。
ImageMatrix.java
import java.util.*;
public class ImageMatrix {
private int[] data;
private int classCode;
public ImageMatrix(int[] data, int classCode) {
assert data.length == 64; //maximum array length of 64
this.data = data;
this.classCode = classCode;
}
public String toString() {
return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable
}
public int[] getData() {
return data;
}
public int getClassCode() {
return classCode;
}
}
ImageMatrixDB.java
import java.util.*;
import java.io.*;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
private List<ImageMatrix> list = new ArrayList<ImageMatrix>();
public ImageMatrixDB load(String f) throws IOException {
try (
FileReader fr = new FileReader(f);
BufferedReader br = new BufferedReader(fr)) {
String line = null;
while((line = br.readLine()) != null) {
int lastComma = line.lastIndexOf(',');
int classCode = Integer.parseInt(line.substring(1 + lastComma));
int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
.mapToInt(Integer::parseInt)
.toArray();
ImageMatrix matrix = new ImageMatrix(data, classCode);
list.add(matrix);
}
}
return this;
}
public void printResults(){ //output results
for(ImageMatrix matrix: list){
System.out.println(matrix);
}
}
public Iterator<ImageMatrix> iterator() {
return this.list.iterator();
}
/// kNN implementation ///
public static int distance(int[] a, int[] b) {
int sum = 0;
for(int i = 0; i < a.length; i++) {
sum += (a[i] - b[i]) * (a[i] - b[i]);
}
return (int)Math.sqrt(sum); //Euclidean sqrt of the sum
}
public static int classify(List<ImageMatrix> trainingSet, int[] curData) {
int label = 0, bestDistance = Integer.MAX_VALUE;
for(ImageMatrix matrix: trainingSet) {
int dist = distance(matrix.getData(), curData);
if(dist < bestDistance) {
bestDistance = dist;
curData = matrix.getData();
}
}
return label;
}
public static void main(String[] argv) throws IOException {
ImageMatrixDB i = new ImageMatrixDB();
List<ImageMatrix> trainingSet = i.load("cw2DataSet1.csv"); // << ERROR HERE
List<ImageMatrix> validationSet = i.load("cw2DataSet2.csv"); //<< ERROR HERE
int numCorrect = 0;
for(ImageMatrix matrix:validationSet) {
if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
}
System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");
}
//////////////////////////////////////////
// Previous working dataset Load //
/* public static void main(String[] args){
ImageMatrixDB i = new ImageMatrixDB();
try{
i.load("cw2DataSet1.csv");
i.printResults();
}
catch(Exception ex){
ex.printStackTrace();
}
} */
}
编辑///
目前的错误信息是:
Exception in thread "main" java.lang.Error: Unresolved compilation problems:
Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
at ImageMatrixDB.main(ImageMatrixDB.java:64)
但是我在测试的时候还报了其他错误。
你设计你的class的方式,应该按如下方式使用:
ImageMatrixDB trainingSet = new ImageMatrixDB();
ImageMatrixDB validationSet = new ImageMatrixDB();
trainingSet.load("cw2DataSet1.csv");
validationSet.load("cw2DataSet2.csv");
注意 ImageMatrixDB 的两个实例而不是一个,这确保了训练/验证数据被加载到不同的列表中。
附带说明一下,在计算 kNN 中的距离时,您应该能够使用平方距离(效率增益,sqrt 是一项昂贵的操作)。所以 return (int)Math.sqrt(sum);
不需要平方根。