java.lang.IllegalArgumentException: 矩阵内部维度必须一致
java.lang.IllegalArgumentException: Matrix inner dimensions must agree
这是我的代码:
package algorithms;
import Jama.Matrix;
import java.io.File;
import java.util.Arrays;
public class ThetaGetter {
//First column is one, second is price and third is BHK
private static double[][] variables = {
{1,1130,2},
{1,1100,2},
{1,2055,3},
{1,1047,2},
{1,1927,3},
{1,2667,3},
{1,1146,2},
{1,2020,3},
{1,1190,2},
{1,2165,3},
{1,1250,2},
{1,1185,2},
{1,2825,4},
{1,1200,2},
{1,1580,3},
{1,3200,3},
{1,715,1},
{1,1270,2},
{1,2403,3},
{1,1465,3},
{1,1345,2}
};
private static double[][] prices = {
{69.65},
{60},
{115},
{55},
{140},
{225},
{76.78},
{120},
{73.11},
{140},
{56},
{79.39},
{161},
{73.69},
{80},
{145},
{34.87},
{77.72},
{165},
{98},
{82}
};
private static Matrix X = new Matrix(variables);
private static Matrix y = new Matrix(prices);
public static void main(String[] args) {
File file = new File("theta.dat");
if(file.exists()){
System.out.println("Theta has already been calculated!");
return;
}
//inverse(Tra(X)*X)*tra(X)*y
Matrix transposeX = X.transpose();
Matrix inverse = X.times(transposeX).inverse();
System.out.println(y.getArray().length);
System.out.println(X.getArray().length);
Matrix test = inverse.times(transposeX);
Matrix theta = test.times(y);
System.out.println(Arrays.deepToString(theta.getArray()));
}
}
这个算法基本上是尝试获取房价,然后得到一些常数,然后用这些常数来猜测房价。但是,我在 'Matrix theta = test.times(y);' 行遇到异常,错误消息几乎就是问题所在。尺寸是否存在某种问题?两个都是21项,不知道怎么回事
您犯的错误在以下代码行中:
Matrix inverse = X.times(transposeX).inverse();
您在上面评论的公式是:
//inverse(Tra(X)*X)*tra(X)*y
但是你在代码中实际计算的是:
(X*Tra(X) 而不是 Tra(X)*X)
//inverse(X*Tra(X))*tra(X)*y
如果 X 的维度是 (m,n) 其中
- m = 行数
- n = 列数
并且 Y 的维度是 (m,1),使用上面使用的乘法,您将得到以下结果:
逆(X * Tra(X)) *Tra(X)*Y = 逆* Tra(X) * Y = 结果* y
逆((m,n)(n,m))(n,m)*(m,1)= (m,m) * (n,m ) => 这会导致错误,因为矩阵乘法的内部维数必须相等
修复您的代码的方法是替换以下行:
Matrix inverse = X.times(transposeX).inverse();
和
Matrix inverse = transposeX.times(X).inverse();
这是我的代码:
package algorithms;
import Jama.Matrix;
import java.io.File;
import java.util.Arrays;
public class ThetaGetter {
//First column is one, second is price and third is BHK
private static double[][] variables = {
{1,1130,2},
{1,1100,2},
{1,2055,3},
{1,1047,2},
{1,1927,3},
{1,2667,3},
{1,1146,2},
{1,2020,3},
{1,1190,2},
{1,2165,3},
{1,1250,2},
{1,1185,2},
{1,2825,4},
{1,1200,2},
{1,1580,3},
{1,3200,3},
{1,715,1},
{1,1270,2},
{1,2403,3},
{1,1465,3},
{1,1345,2}
};
private static double[][] prices = {
{69.65},
{60},
{115},
{55},
{140},
{225},
{76.78},
{120},
{73.11},
{140},
{56},
{79.39},
{161},
{73.69},
{80},
{145},
{34.87},
{77.72},
{165},
{98},
{82}
};
private static Matrix X = new Matrix(variables);
private static Matrix y = new Matrix(prices);
public static void main(String[] args) {
File file = new File("theta.dat");
if(file.exists()){
System.out.println("Theta has already been calculated!");
return;
}
//inverse(Tra(X)*X)*tra(X)*y
Matrix transposeX = X.transpose();
Matrix inverse = X.times(transposeX).inverse();
System.out.println(y.getArray().length);
System.out.println(X.getArray().length);
Matrix test = inverse.times(transposeX);
Matrix theta = test.times(y);
System.out.println(Arrays.deepToString(theta.getArray()));
}
}
这个算法基本上是尝试获取房价,然后得到一些常数,然后用这些常数来猜测房价。但是,我在 'Matrix theta = test.times(y);' 行遇到异常,错误消息几乎就是问题所在。尺寸是否存在某种问题?两个都是21项,不知道怎么回事
您犯的错误在以下代码行中:
Matrix inverse = X.times(transposeX).inverse();
您在上面评论的公式是:
//inverse(Tra(X)*X)*tra(X)*y
但是你在代码中实际计算的是: (X*Tra(X) 而不是 Tra(X)*X)
//inverse(X*Tra(X))*tra(X)*y
如果 X 的维度是 (m,n) 其中
- m = 行数
- n = 列数
并且 Y 的维度是 (m,1),使用上面使用的乘法,您将得到以下结果:
逆(X * Tra(X)) *Tra(X)*Y = 逆* Tra(X) * Y = 结果* y
逆((m,n)(n,m))(n,m)*(m,1)= (m,m) * (n,m ) => 这会导致错误,因为矩阵乘法的内部维数必须相等
修复您的代码的方法是替换以下行:
Matrix inverse = X.times(transposeX).inverse();
和
Matrix inverse = transposeX.times(X).inverse();