Issue loading data sets for kNN algorithm - Java

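I have applied the kNN algorithm to classify handwritten digits. Each digit starts out as an 8*8 matrix and is flattened into a 1*64 vector; every record carries a class code in the range 0..9.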

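As far as I can tell, my code should work in theory, but this is my first time experimenting with this algorithm. The problem appears when I try to feed my data sets through the algorithm: an error is thrown on the lines I have marked in the code. The training data set can be found here and the validation set here. In case it helps, I have also left my previous, working main function in the code.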

ImageMatrix.java

import java.util.*;

public class ImageMatrix {
    private int[] data;
    private int classCode;

    public ImageMatrix(int[] data, int classCode) {
        assert data.length == 64; // each image is expected to have exactly 64 values
        this.data = data;
        this.classCode = classCode;
    }

    public String toString() {
        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable
    }

    public int[] getData() {
        return data;
    }

    public int getClassCode() {
        return classCode;
    }

}

ImageMatrixDB.java

import java.util.*;
import java.io.*;

public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line = null;

            while((line = br.readLine()) != null) {
                int lastComma = line.lastIndexOf(',');
                int classCode = Integer.parseInt(line.substring(1 + lastComma));
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                ImageMatrix matrix = new ImageMatrix(data, classCode);
                list.add(matrix);
            }
        }
        return this;
    }

    public void printResults(){ //output results 
        for(ImageMatrix matrix: list){
            System.out.println(matrix);
        }
    }


    public Iterator<ImageMatrix> iterator() {
        return this.list.iterator();
    }

    /// kNN implementation ///
    public static int distance(int[] a, int[] b) {
        int sum = 0;
        for(int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int)Math.sqrt(sum); //Euclidean sqrt of the sum 
    }


    public static int classify(List<ImageMatrix> trainingSet, int[] curData) {
        int label = 0, bestDistance = Integer.MAX_VALUE;
        for(ImageMatrix matrix: trainingSet) {
            int dist = distance(matrix.getData(), curData);
            if(dist < bestDistance) {
                bestDistance = dist;
                curData = matrix.getData();
            }
        }
        return label;
    }


    public static void main(String[] argv) throws IOException {
        ImageMatrixDB i = new ImageMatrixDB();
        List<ImageMatrix> trainingSet = i.load("cw2DataSet1.csv"); // << ERROR HERE
        List<ImageMatrix> validationSet = i.load("cw2DataSet2.csv"); //<< ERROR HERE
        int numCorrect = 0;
        for(ImageMatrix matrix:validationSet) {
            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
        }
        System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");
    }
    //////////////////////////////////////////

    // Previous working dataset Load //
 /*   public static void main(String[] args){
        ImageMatrixDB i = new ImageMatrixDB();
        try{
            i.load("cw2DataSet1.csv"); 
            i.printResults();
        }
        catch(Exception ex){
            ex.printStackTrace();
        }
    } */

}

Edit ///

The current error message is:

Exception in thread "main" java.lang.Error: Unresolved compilation problems: 
    Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
    Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
    at ImageMatrixDB.main(ImageMatrixDB.java:64)

I also got other errors while testing, though.

Given the way you have designed your class, it should be used as follows:

ImageMatrixDB trainingSet = new ImageMatrixDB();
ImageMatrixDB validationSet = new ImageMatrixDB();
trainingSet.load("cw2DataSet1.csv");
validationSet.load("cw2DataSet2.csv");

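Note that there are two instances of ImageMatrixDB rather than one; this ensures that the training and validation data are loaded into separate lists.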
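Once the two sets are ImageMatrixDB instances, the rest of your main needs a little more work, because ImageMatrixDB is an Iterable and not a List: classify(List<ImageMatrix>, ...) will not accept it, and validationSet.size() does not exist. Your classify also never assigns label (it overwrites curData instead), so it would always return 0. Below is a minimal sketch of one way to put this together, not the only possible design. The size(), loadLines() and accuracy() methods are hypothetical additions that are not in your code, load() is rewritten here with Files.readAllLines for brevity, and ImageMatrix is a reduced copy so the block compiles on its own.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Reduced copy of ImageMatrix from the question, just enough for this sketch.
class ImageMatrix {
    private final int[] data;
    private final int classCode;

    ImageMatrix(int[] data, int classCode) {
        this.data = data;
        this.classCode = classCode;
    }

    int[] getData() { return data; }
    int getClassCode() { return classCode; }
}

class ImageMatrixDB implements Iterable<ImageMatrix> {
    private final List<ImageMatrix> list = new ArrayList<>();

    // Reads the CSV file, then delegates to loadLines.
    ImageMatrixDB load(String f) throws IOException {
        return loadLines(Files.readAllLines(Paths.get(f)));
    }

    // Hypothetical helper (not in the question): same parsing as load(),
    // but on lines already in memory, so it can be tried without the files.
    ImageMatrixDB loadLines(List<String> lines) {
        for (String line : lines) {
            int lastComma = line.lastIndexOf(',');
            int classCode = Integer.parseInt(line.substring(lastComma + 1));
            int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                               .mapToInt(Integer::parseInt)
                               .toArray();
            list.add(new ImageMatrix(data, classCode));
        }
        return this;
    }

    // Hypothetical addition: Iterable has no size(), but main() needs one.
    int size() { return list.size(); }

    @Override
    public Iterator<ImageMatrix> iterator() { return list.iterator(); }

    // Distance exactly as in the question (truncated Euclidean distance).
    static int distance(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int) Math.sqrt(sum);
    }

    // Accepts any Iterable, so an ImageMatrixDB can be passed directly.
    static int classify(Iterable<ImageMatrix> trainingSet, int[] curData) {
        int label = 0, bestDistance = Integer.MAX_VALUE;
        for (ImageMatrix matrix : trainingSet) {
            int dist = distance(matrix.getData(), curData);
            if (dist < bestDistance) {
                bestDistance = dist;
                label = matrix.getClassCode(); // remember the nearest neighbour's class
            }
        }
        return label;
    }

    // Percentage of validation samples whose predicted class matches the stored one.
    static double accuracy(ImageMatrixDB trainingSet, ImageMatrixDB validationSet) {
        int numCorrect = 0;
        for (ImageMatrix matrix : validationSet) {
            if (classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
        }
        return (double) numCorrect / validationSet.size() * 100;
    }
}
```

With something like this in place, main only has to load the two instances as shown above and print the value returned by accuracy(trainingSet, validationSet).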

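As a side note, when computing distances in kNN you should be able to use the squared distance (an efficiency gain, since sqrt is an expensive operation). So the square root in return (int)Math.sqrt(sum); is not needed.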
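Dropping the square root does not change which neighbour is nearest, because sqrt is monotonic on non-negative values. It also avoids a side effect of the (int) cast, which truncates the result and can make two different distances compare as equal. A small sketch to illustrate; the class name DistanceSketch and both method names are made up for this example, and the pixel values are assumed to be small enough that the sum fits in an int.

```java
class DistanceSketch {
    // Squared Euclidean distance: no sqrt, no loss of precision.
    static int squaredDistance(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            int diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // What the question's distance() computes: the Euclidean distance
    // truncated to an int by the cast.
    static int truncatedDistance(int[] a, int[] b) {
        return (int) Math.sqrt(squaredDistance(a, b));
    }
}
```

For example, measured from {0, 0}, the points {3, 1} and {3, 2} have squared distances 10 and 13, but both truncated distances are 3, so the truncated version cannot tell which of the two is closer.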