使用 LAPACK dgesvd_ 在非方矩阵上进行 SVD

SVD on a non-square matrix using LAPACK dgesvd_

我需要对一个非方阵计算 SVD,为此使用了 LAPACK 的 dgesvd_ 例程。对于方阵,结果与 MATLAB 一致,没有任何问题;但对于 4x5 矩阵,我得不到预期的结果。由于返回的奇异值按降序排列,我知道结果应该与 MATLAB 的一致。我还注意到,部分奇异值可以在原始输入数组 A 中直接找到。这说明要么我调用 dgesvd_ 的方式有误,要么读取结果的方式有误,这可能与前导维度(leading dimension)有关。

在每种情况下,我都先以 LWORK = -1 调用一次,向 LAPACK 查询最优值,再把这些值作为下一次(真正计算 SVD 的)调用的输入。我不完全清楚这些返回值的全部含义,也不确定它们是否有效、是否需要修改等。我认为它们没有问题,所以在接下来计算 SVD 的调用中直接使用了它们。

以下代码按预期工作(3x3 矩阵):

 41 /* Reference data for A = [1 2 3; 2 4 5; 3 5 6], copied from the MATLAB
     * session below. ref_array_Sigma holds only the diagonal of S, and
     * ref_array_VT holds V^T (the transpose of MATLAB's V). These C arrays are
     * row-major, while dgesvd_ expects column-major storage; because this A is
     * symmetric, both layouts contain the same numbers. */
 42 double ref_array_A[3][3] = {
 43     { 1, 2, 3},
 44     { 2, 4, 5 },
 45     { 3, 5, 6 }
 46 };
 47 
 48 double ref_array_U[3][3] = {
 49     { -0.327985, -0.736976, -0.591009 },
 50     { -0.591009, -0.327985, 0.736976 },
 51     { -0.736976, 0.591009, -0.327985 }
 52 };
 53 
 54 double ref_array_Sigma[3][1] = {
 55     { 11.344814 },
 56     { 0.515729 },
 57     { 0.170915 }
 58 };
 59 
 60 double ref_array_VT[3][3] = {
 61     { -0.327985, -0.591009, -0.736976 },
 62     { 0.736976, 0.327985, -0.591009 },
 63     { -0.591009, 0.736976, -0.327985 }
 64 };
 66 /* MATLAB result
 67  *
 68  *  >> A = [ 1, 2, 3; 2, 4, 5; 3, 5, 6]
 69  *
 70  *  A = 
 71  *      1     2     3
 72  *      2     4     5
 73  *      3     5     6
 74  *
 75  *  >> [U, S, V] = svd(A)
 76  *
 77  *  U =
 78  *      -0.3280   -0.7370   -0.5910
 79  *      -0.5910   -0.3280    0.7370
 80  *      -0.7370    0.5910   -0.3280
 81  *
 82  *  S =
 83  *      11.3448     0           0
 84  *      0           0.5157      0
 85  *      0           0           0.1709
 86  *
 87  *  V =
 88  *      -0.3280    0.7370   -0.5910
 89  *      -0.5910    0.3280    0.7370
 90  *      -0.7370   -0.5910   -0.3280
 91  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 3;
213     N = 3;
214     LDA = 3;            /* (out) */
215     LDU = 3;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 3;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(dd_min(M, N), sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:\n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:\n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:\n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:\n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
262     if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %f\n", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %d\n", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completed\n");
292 
293     fprintf(stderr, "Result A:\n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U**T:\n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298     fprintf(stderr, "Result U:\n");
299     dd_walk_dbl_arr_colwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
300 
301 
302     fprintf(stderr, "Result S:\n");
303     dd_walk_dbl_arr_rowwise(S, dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
304 
305     fprintf(stderr, "Result VT:\n");
306     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
307 
308     free(WORK);
309     free(A);
310     free(S);
311     free(U);
312     free(VT);
313 
314     return 0;

正确结果:

peter@xx:~$ ./test4
Reference array A:
    1.000000    2.000000    3.000000
    2.000000    4.000000    5.000000
    3.000000    5.000000    6.000000
Reference array U:
    -0.327985   -0.736976   -0.591009
    -0.591009   -0.327985   0.736976
    -0.736976   0.591009    -0.327985
Reference array Sigma:
    11.344814
    0.515729
    0.170915
Reference array VT:
    -0.327985   -0.591009   -0.736976
    0.736976    0.327985    -0.591009
    -0.591009   0.736976    -0.327985
LAPACK's dgesvd_ query optimal results: LDA 3, LDU 3, LDVT 3, LWORK 201, WORK_QUERY 201.000000
Rest of params: M 3, N 3
LAPACK's dgesvd_ SVD completed
Result A:
    -3.741657   0.421793    0.632690
    10.643576   1.261481    -0.720622
    0.478213    -0.279401   -0.211863
Result U**T:
    -0.327985   -0.591009   -0.736976
    -0.736976   -0.327985   0.591009
    -0.591009   0.736976    -0.327985
Result U:
    -0.327985   -0.736976   -0.591009
    -0.591009   -0.327985   0.736976
    -0.736976   0.591009    -0.327985
Result S:
    11.344814
    0.515729
    0.170915
Result VT:
    -0.327985   0.736976    -0.591009
    -0.591009   0.327985    0.736976
    -0.736976   -0.591009   -0.327985

但下面这段代码(4x5 矩阵)的结果不对:

 39 /* Reference data.
     *
     * NOTE(review): ref_array_Sigma is the full 4x5 Sigma with the singular
     * values listed as 2, 3, 2.236068, which is not the descending order that
     * MATLAB (3, 2.2361, 2 in the session below) and dgesvd_ use. dgesvd_
     * itself returns only the min(M, N) singular values, in the 1-D array S.
     * U and VT below are ordered to match this Sigma, so even a correct
     * dgesvd_ call will return their columns/rows in a different order (and
     * possibly with different signs).
     *
     * These C arrays are row-major; dgesvd_ expects column-major storage. */
 40 double ref_array_A[4][5] = {
 41     { 1, 0, 0, 0, 2 },
 42     { 0, 0, 3, 0, 0 },
 43     { 0, 0, 0, 0, 0 },
 44     { 0, 2, 0, 0, 0 }
 45 };
 46 
 47 double ref_array_U[4][4] = {
 48     { 0, 0, 1, 0 },
 49     { 0, 1, 0, 0 },
 50     { 0, 0, 0, -1 },
 51     { 1, 0, 0, 0 }
 52 };
 53 
 54 double ref_array_Sigma[4][5] = {
 55     { 2, 0, 0, 0, 0 },
 56     { 0, 3, 0, 0, 0 },
 57     { 0, 0, 2.236068, 0, 0 },
 58     { 0, 0, 0, 0, 0 }
 59 };
 60 
 61 double ref_array_VT[5][5] = {
 62     { 0, 1, 0, 0, 0 },
 63     { 0, 0, 1, 0, 0 },
 64     { 0.447214, 0, 0, 0, 0.894427 },
 65     { 0, 0, 0, 1, 0 },
 66     { -0.894427, 0, 0, 0, -0.447214 }
 67 };
 68 
 69 /* MATLAB result
 70  *
 71  * >> A = [ 1 0 0 0 2; 0 0 3 0 0 ; 0 0 0 0 0 ;0 2 0 0 0 ];
 72  * >> [U, S, V] = svd(A)
 73  *
 74  * U =
 75  *      0     1     0     0
 76  *      1     0     0     0
 77  *      0     0     0    -1
 78  *      0     0     1     0
 79  *
 80  * S =
 81  *      3.0000      0           0           0           0
 82  *      0           2.2361      0           0           0
 83  *      0           0           2.0000      0           0
 84  *      0           0           0           0           0
 85  *
 86  * V =
 87  *      0           0.4472      0           0           -0.8944
 88  *      0           0           1.0000      0           0
 89  *      1.0000      0           0           0           0
 90  *      0           0           0           1.0000      0
 91  *      0           0.8944      0           0           0.4472
 92  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 4;
213     N = 5;
214     LDA = 4;            /* (out) */
215     LDU = 4;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 5;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(M * N, sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:\n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:\n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:\n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], M, N, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:\n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
259 
260     LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %f\n", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %d\n", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completed\n");
292 
293     fprintf(stderr, "Result A:\n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U:\n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298 
299     fprintf(stderr, "Result S:\n");
300     dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end);
301 
302     fprintf(stderr, "Result VT:\n");
303     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
304 
305     free(WORK);
306     free(A);
307     free(S);
308     free(U);
309     free(VT);
310 
311     return 0;

错误结果:

peter@xx:~/$ ./test2
Reference array A:
    1.000000    0.000000    0.000000    0.000000    2.000000
    0.000000    0.000000    3.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    2.000000    0.000000    0.000000    0.000000
Reference array U:
    0.000000    0.000000    1.000000    0.000000
    0.000000    1.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    -1.000000
    1.000000    0.000000    0.000000    0.000000
Reference array Sigma:
    2.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    3.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    2.236068    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
Reference array VT:
    0.000000    1.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    1.000000    0.000000    0.000000
    0.447214    0.000000    0.000000    0.000000    0.894427
    0.000000    0.000000    0.000000    1.000000    0.000000
    -0.894427   0.000000    0.000000    0.000000    -0.447214
LAPACK's dgesvd_ query optimal results: LDA 4, LDU 4, LDVT 5, LWORK 300, WORK_QUERY 300.000000
Rest of params: M 4, N 5
LAPACK's dgesvd_ SVD completed
Result A:
    -3.000000   -2.000000   0.000000    -1.000000   0.500000
    -2.236068   0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.500000    -0.236068   0.000000    0.000000
Result U:
    0.707107    0.000000    0.000000    0.707107
    -0.707107   0.000000    -0.000000   0.707107
    0.000000    0.000000    1.000000    0.000000
    0.000000    1.000000    0.000000    0.000000
Result S:
    3.872983    1.732051    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
Result VT:
    0.182574    -0.408248   0.000000    0.000000    -0.894427
    0.912871    0.408248    0.000000    0.000000    0.000000
    -0.000000   -0.000000   1.000000    0.000000    0.000000
    -0.000000   -0.000000   0.000000    1.000000    0.000000
    0.365148    -0.816497   0.000000    0.000000    0.447214

对于一般(非方阵)的情况,我哪里做错了?

函数 dgesvd_ 要求矩阵按列主序(column-major)存储,而您的代码是按行主序(row-major)填充数据的:

227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }

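实际上,您的代码计算的是下面这个矩阵的 SVD。dgesvd_ 把内存中每连续 M = 4 个数当作一列,因此它看到的是左边的矩阵。它等于右边 5×4 矩阵的转置,而右边矩阵的每一行正是内存中连续的 4 个元素:

$$
\begin{bmatrix} 1 & 2 & 0 & 0 & 2 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 3 & 0 & 0 & 0 \end{bmatrix}
=
\begin{bmatrix} 1 & 0 & 0 & 0 \\ 2 & 0 & 0 & 3 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 2 & 0 & 0 & 0 \end{bmatrix}^{T}
$$

这个矩阵的奇异值确实约为 3.87 和 1.73(即 $\sqrt{15}$ 和 $\sqrt{3}$),与您得到的结果一致。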

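第一个示例中不会出现此错误,因为那个矩阵是方阵 (M=N) 且对称,按列主序重新解释(即转置)后保持不变。对于一般矩阵,应按列主序填充,例如 A[i + j * LDA] = ref_array_A[i][j];。同理,返回的 U 和 VT 也按列主序存储,这正是第一个示例中需要用 dd_walk_dbl_arr_colwise 才能正确打印出 U 的原因。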

此外,参数 S 只是一个长度为 min(M, N) 的一维数组(与第一个示例相同)。由于您随后用 dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end); 把它当作 M×N 的行主序矩阵打印,所有奇异值都会连续出现在第一行中...