How to make a loop for determining the class in the KNN method

I am trying to write a program for the KNN method, but I am stuck on the loop that determines the class. Here is my code:

public static void main(String[] args) {
    int titikx, titiky, k;
    int[] titikxl = new int[]{1, 1, 3, 2, 4, 6}; //first feature
    int[] titiky1 = new int[]{1, 3, 1, 5, 3, 2}; //second feature
    ArrayList<Double> nx = new ArrayList<Double>(), ny = new ArrayList<Double>(), 
    fn = new ArrayList<Double>(), arclass1 = new ArrayList<Double>(), 
    arclass2 = new ArrayList<Double>();

    //input new data's features and k
    Scanner input = new Scanner(System.in);
    System.out.println("Input first feature : ");
    titikx = input.nextInt();
    System.out.println("Input second feature : ");
    titiky = input.nextInt();
    System.out.println("Input k : ");
    k = input.nextInt();

    //count distance between new data and training data
    int i = 0, j = 0;
    while(i < titikxl.length || j <titiky1.length){
        nx.add(Math.pow(titikx - titikxl[i], 2));
        ny.add(Math.pow(titiky - titiky1[j], 2));
        i++;
        j++;
    }

    System.out.println(nx);
    System.out.println(ny);

    //convert arraylist to array
    Double[] nxarray = nx.toArray(new Double[nx.size()]);
    Double[] nyarray = ny.toArray(new Double[ny.size()]);

    //sum the array of distance first feature and second feature to get result

    int ii = 0, jj = 0;
    while (ii < nxarray.length || jj < nyarray.length){
        fn.add(Math.sqrt(nxarray[ii] + nyarray[jj]));
        ii++;
        jj++;
    }

    System.out.println(fn);
    Double[] fnarray = fn.toArray(new Double[fn.size()]);
    Double[] oldfnarray = fnarray; //array result before sort ascending

    //ascending array
    for(int id1 = 0; id1 < fnarray.length; id1++){
        for(int id2 = id1 + 1; id2 < fnarray.length; id2++){
            if(fnarray[id1]>fnarray[id2]){
                double temp = fnarray[id2];
                fnarray[id2] = fnarray[id1];
                fnarray[id1] = temp;
            }
        }
    }

    for(int id = 0; id < fnarray.length; id++){
        System.out.print(fnarray[id] + " ");
    }
    System.out.println();

    double[] classa = new double[]{oldfnarray[0], oldfnarray[1], oldfnarray[2]};
    double[] classb = new double[]{oldfnarray[3], oldfnarray[4], oldfnarray[5]};

    //determining what class the new data belongs
    for(int idb = 0; idb < classa.length; idb++){
        for(int idc = 0; idc < classb.length; idc++){
            for(int ida = 0; ida < fnarray.length; ida++){
                while(ida < k){
                    if (classa[idb] == fnarray[ida]){
                        arclass1.add(fnarray[ida]);
                    } 
                    if (classb[idc] == fnarray[ida]){
                        arclass2.add(fnarray[ida]);
                    }
                }    
            }
        }
    }

    if(arclass1.size() < arclass2.size()){
        System.out.println("The Class is B");
    } else{
        System.out.println("The Class is A");
    }
}

The result is:

Input first feature : 2
Input second feature : 3
Input k : 3
[1.0, 1.0, 1.0, 0.0, 4.0, 16.0] //distance feature x
[4.0, 0.0, 4.0, 4.0, 0.0, 1.0] //distance feature y
[2.23606797749979, 1.0, 2.23606797749979, 2.0, 2.0, 4.123105625617661] //result
1.0 2.0 2.0 2.23606797749979 2.23606797749979 4.123105625617661 //ascended result
//looping forever 
BUILD STOPPED (total time: 35 seconds)

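As you can see, in the part that determines the class, the program seems to execute the last loop forever once it enters it. I have tried modifying it as much as I can, but I still get no result, only errors or a program that runs forever. Please help. Thank you very much.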

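The problem is the infinite while loop: the value of ida never changes inside it, so once ida < k is true it stays true. I suggest reworking the whole code block, because it is more complicated than it needs to be.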
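A minimal sketch of the smallest possible fix, assuming a hypothetical class `LoopFix` and method `firstK` that are not part of the original program. Instead of nesting `while (ida < k)` inside the `for`, the `for` loop itself is bounded by k:

```java
import java.util.ArrayList;
import java.util.List;

class LoopFix {

    // Collects the k smallest entries of an already sorted distance array.
    // (Hypothetical helper, used only to illustrate the loop bound.)
    static List<Double> firstK(double[] sortedDistances, int k) {
        List<Double> nearest = new ArrayList<>();
        // Broken form from the question:
        //   for (int ida = 0; ida < sortedDistances.length; ida++) {
        //       while (ida < k) { ... }   // ida is never changed here, so it never exits
        //   }
        // Bounded form: the for loop stops at k, or at the array end if k is larger.
        int limit = Math.min(k, sortedDistances.length);
        for (int ida = 0; ida < limit; ida++) {
            nearest.add(sortedDistances[ida]);
        }
        return nearest;
    }

    public static void main(String[] args) {
        double[] sorted = {1.0, 2.0, 2.0, 2.23606797749979, 2.23606797749979, 4.123105625617661};
        System.out.println(firstK(sorted, 3)); // prints [1.0, 2.0, 2.0]
    }
}
```

Replacing the `while` with an `if (ida < k)` would also terminate, but bounding the loop avoids visiting elements that can never be among the k nearest.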

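Before proposing a solution, let's establish what we already know:

  • The k nearest neighbors (the entries of the sorted distance array at indices less than k)
  • Which neighbors belong to which class

Clearly, to determine the class we need to:

  • Check, for each neighbor (outer loop), whether it is in class A (a first inner loop) or in class B (a second inner loop)
  • If the neighbor is found in class A, increment the counter for class A; otherwise do the same for class B
  • Compare the counters: if the counter for class A is greater, the new data point belongs to class A, otherwise to class B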

To better understand what is going on, the k-NN classification algorithm is described thoroughly in this tutorial.

The code (same structure as yours, although I renamed the variables and simplified some parts for readability):

import java.util.Arrays;
import java.util.Scanner;

public class KNN {

    public static void main(String[] args) {
        int[] feature1 = new int[] { 1, 1, 3, 2, 4, 6 };
        int[] feature2 = new int[] { 1, 3, 1, 5, 3, 2 };

        //input new data's features and k
        Scanner input = new Scanner(System.in);

        System.out.println("Input first feature : ");
        int newFeature1 = input.nextInt();

        System.out.println("Input second feature : ");
        int newFeature2 = input.nextInt();

        System.out.println("Input k : ");
        int k = input.nextInt();

        input.close();

        //count distance between new data and training data
        double[] distances1 = new double[feature1.length];
        double[] distances2 = new double[feature2.length];

        for (int i = 0; i < distances1.length; i++) {
            distances1[i] = Math.pow(newFeature1 - feature1[i], 2);
        }

        for (int i = 0; i < distances2.length; i++) {
            distances2[i] = Math.pow(newFeature2 - feature2[i], 2);
        }

        System.out.println("Distance between first feature and first feature training data: " + Arrays.toString(distances1));
        System.out.println("Distance between second feature and second feature training data: " + Arrays.toString(distances2));

        //sum the array of distance first feature and second feature to get result
        double[] distanceSums = new double[distances1.length];

        for (int i = 0; i < distances1.length; i++) {
            distanceSums[i] = Math.sqrt(distances1[i] + distances2[i]);
        }

        System.out.println("Distance sums: " + Arrays.toString(distanceSums));

        // sort array ascending
        double[] distanceSumsSorted = new double[distanceSums.length];
        System.arraycopy(distanceSums, 0, distanceSumsSorted, 0, distanceSums.length);
        Arrays.sort(distanceSumsSorted);

        System.out.println("Sorted distance sums: " + Arrays.toString(distanceSumsSorted));

        double[] classAMembers = new double[] { distanceSums[0], distanceSums[1], distanceSums[2] };
        double[] classBMembers = new double[] { distanceSums[3], distanceSums[4], distanceSums[5] };

        //determining what class the new data belongs
        int classACounts = 0;
        int classBCounts = 0;

        for (int i = 0; i < k; i++) {
            // check if nearest neighbor belongs to class A
            for (int j = 0; j < classAMembers.length; j++) {
                if (distanceSumsSorted[i] == classAMembers[j]) {
                    classACounts++;
                    break;
                }
            }
            // check if nearest neighbor belongs to class B
            for (int j = 0; j < classBMembers.length; j++) {
                if (distanceSumsSorted[i] == classBMembers[j]) {
                    classBCounts++;
                    break;
                }
            }
        }

        System.out.println("Class A members: " + Arrays.toString(classAMembers));
        System.out.println("Class B members: " + Arrays.toString(classBMembers));

        System.out.println("Counts for class A: " + classACounts);
        System.out.println("Counts for class B: " + classBCounts);

        if (classACounts < classBCounts){
            System.out.println("The Class is B.");
        }
        else {
            System.out.println("The Class is A.");
        }
    }
}

For your sample data, the program outputs:

Input first feature : 2
Input second feature : 3
Input k : 3
Distance between first feature and first feature training data: [1.0, 1.0, 1.0, 0.0, 4.0, 16.0]
Distance between second feature and second feature training data: [4.0, 0.0, 4.0, 4.0, 0.0, 1.0]
Distance sums: [2.23606797749979, 1.0, 2.23606797749979, 2.0, 2.0, 4.123105625617661]
Sorted distance sums: [1.0, 2.0, 2.0, 2.23606797749979, 2.23606797749979, 4.123105625617661]
Class A members: [2.23606797749979, 1.0, 2.23606797749979]
Class B members: [2.0, 2.0, 4.123105625617661]
Counts for class A: 1
Counts for class B: 2
The Class is B.
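
One caveat with the code above: it decides class membership by comparing distance values with ==, so a class A point and a class B point that lie at exactly the same distance would both be counted. A sketch of one possible alternative follows, assuming (as in the code above) that the first three training points are class A and the last three are class B. The names `KnnByIndex`, `classify` and `labels` are hypothetical and not part of the original. The idea is to sort the indices of the training points by distance, so each neighbor keeps its own label:

```java
import java.util.Arrays;
import java.util.Comparator;

class KnnByIndex {

    // Returns 'A' or 'B' by majority vote among the k nearest training points.
    // Ties go to 'A', matching the final comparison in the program above.
    static char classify(int[] feature1, int[] feature2, char[] labels,
                         int newFeature1, int newFeature2, int k) {
        int n = feature1.length;
        double[] distances = new double[n];
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            // Euclidean distance: sqrt(dx^2 + dy^2)
            distances[i] = Math.hypot(newFeature1 - feature1[i], newFeature2 - feature2[i]);
            order[i] = i;
        }
        // Sort indices rather than values, so order[i] still identifies a training point.
        Arrays.sort(order, Comparator.comparingDouble(idx -> distances[idx]));

        int countA = 0;
        int countB = 0;
        for (int i = 0; i < Math.min(k, n); i++) {
            if (labels[order[i]] == 'A') {
                countA++;
            } else {
                countB++;
            }
        }
        return countA < countB ? 'B' : 'A';
    }

    public static void main(String[] args) {
        int[] feature1 = {1, 1, 3, 2, 4, 6};
        int[] feature2 = {1, 3, 1, 5, 3, 2};
        char[] labels = {'A', 'A', 'A', 'B', 'B', 'B'}; // assumed class assignment
        System.out.println("The Class is " + classify(feature1, feature2, labels, 2, 3, 3) + ".");
        // prints: The Class is B.
    }
}
```

For the sample input (2, 3) with k = 3, the three nearest points are one from class A and two from class B, so the result agrees with the output shown above.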