使用 LAPACK 的 dgesvd_ 计算非方阵的 SVD



我需要对一个非方阵计算 SVD,使用的是 LAPACK 的 dgesvd_ 例程。对于方阵没有任何问题:与 MATLAB 对比,得到的都是预期值。但对于 4x5 矩阵,我无法得到预期结果。由于返回的奇异值是按降序排列的,我知道结果应当与 MATLAB 的结果一致。不过我注意到,其中一些奇异值却出现在了传给 SVD 的原始输入数组 A 中。这说明要么是我调用 dgesvd_ 的方式不对,要么是我读取结果的方式不对,问题可能与数组的主维度(leading dimension,即 LDA、LDU、LDVT)有关。

在每种情况下,我都先以 LWORK = -1 调用一次 dgesvd_,向 LAPACK 查询最佳工作区大小,再把查询到的值作为随后实际计算 SVD 的那次调用的输入。我不确定返回值的全部含义,也不确定它们是否有效、是否需要修改;我假设它们没有问题,因此在下面的调用中直接用它们来计算 SVD。

所以这段代码按预期工作(3x3 矩阵):

41 /* Reference data. */
42 double ref_array_A[3][3] = {
43     { 1, 2, 3},
44     { 2, 4, 5 },
45     { 3, 5, 6 }
46 };
47 
48 double ref_array_U[3][3] = {
49     { -0.327985, -0.736976, -0.591009 },
50     { -0.591009, -0.327985, 0.736976 },
51     { -0.736976, 0.591009, -0.327985 }
52 };
53 
54 double ref_array_Sigma[3][1] = {
55     { 11.344814 },
56     { 0.515729 },
57     { 0.170915 }
58 };
59 
60 double ref_array_VT[3][3] = {
61     { -0.327985, -0.591009, -0.736976 },
62     { 0.736976, 0.327985, -0.591009 },
63     { -0.591009, 0.736976, -0.327985 }
64 };
66 /* MATLAB result
67  *
68  *  >> A = [ 1, 2, 3; 2, 4, 5; 3, 5, 6]
69  *
70  *  A = 
71  *      1     2     3
72  *      2     4     5
73  *      3     5     6
74  *
75  *  >> [U, S, V] = svd(A)
76  *
77  *  U =
78  *      -0.3280   -0.7370   -0.5910
79  *      -0.5910   -0.3280    0.7370
80  *      -0.7370    0.5910   -0.3280
81  *
82  *  S =
83  *      11.3448     0           0
84  *      0           0.5157      0
85  *      0           0           0.1709
86  *
87  *  V =
88  *      -0.3280    0.7370   -0.5910
89  *      -0.5910    0.3280    0.7370
90  *      -0.7370   -0.5910   -0.3280
91  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 3;
213     N = 3;
214     LDA = 3;            /* (out) */
215     LDU = 3;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 3;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(dd_min(M, N), sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
262     if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "the %d-th argument had illegal value"n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "DBDSDC didn't converge, updating process failed"n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %fn", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %dn", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "the %d-th argument had illegal value"n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "DBDSDC didn't converge, updating process failed"n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completedn");
292 
293     fprintf(stderr, "Result A:n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U**T:n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298     fprintf(stderr, "Result U:n");
299     dd_walk_dbl_arr_colwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
300 
301 
302     fprintf(stderr, "Result S:n");
303     dd_walk_dbl_arr_rowwise(S, dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
304 
305     fprintf(stderr, "Result VT:n");
306     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
307 
308     free(WORK);
309     free(A);
310     free(S);
311     free(U);
312     free(VT);
313 
314     return 0;

正确的结果:

peter@xx:~$ ./test4
Reference array A:
1.000000    2.000000    3.000000
2.000000    4.000000    5.000000
3.000000    5.000000    6.000000
Reference array U:
-0.327985   -0.736976   -0.591009
-0.591009   -0.327985   0.736976
-0.736976   0.591009    -0.327985
Reference array Sigma:
11.344814
0.515729
0.170915
Reference array VT:
-0.327985   -0.591009   -0.736976
0.736976    0.327985    -0.591009
-0.591009   0.736976    -0.327985
LAPACK's dgesvd_ query optimal results: LDA 3, LDU 3, LDVT 3, LWORK 201, WORK_QUERY 201.000000
Rest of params: M 3, N 3
LAPACK's dgesvd_ SVD completed
Result A:
-3.741657   0.421793    0.632690
10.643576   1.261481    -0.720622
0.478213    -0.279401   -0.211863
Result U**T:
-0.327985   -0.591009   -0.736976
-0.736976   -0.327985   0.591009
-0.591009   0.736976    -0.327985
Result U:
-0.327985   -0.736976   -0.591009
-0.591009   -0.327985   0.736976
-0.736976   0.591009    -0.327985
Result S:
11.344814
0.515729
0.170915
Result VT:
-0.327985   0.736976    -0.591009
-0.591009   0.327985    0.736976
-0.736976   -0.591009   -0.327985

但下面这段代码的结果不符合预期(4x5 矩阵):

39 /* Reference data. */
40 double ref_array_A[4][5] = {
41     { 1, 0, 0, 0, 2 },
42     { 0, 0, 3, 0, 0 },
43     { 0, 0, 0, 0, 0 },
44     { 0, 2, 0, 0, 0 }
45 };
46 
47 double ref_array_U[4][4] = {
48     { 0, 0, 1, 0 },
49     { 0, 1, 0, 0 },
50     { 0, 0, 0, -1 },
51     { 1, 0, 0, 0 }
52 };
53 
54 double ref_array_Sigma[4][5] = {
55     { 2, 0, 0, 0, 0 },
56     { 0, 3, 0, 0, 0 },
57     { 0, 0, 2.236068, 0, 0 },
58     { 0, 0, 0, 0, 0 }
59 };
60 
61 double ref_array_VT[5][5] = {
62     { 0, 1, 0, 0, 0 },
63     { 0, 0, 1, 0, 0 },
64     { 0.447214, 0, 0, 0, 0.894427 },
65     { 0, 0, 0, 1, 0 },
66     { -0.894427, 0, 0, 0, -0.447214 }
67 };
68 
69 /* MATLAB result
70  *
71  * >> A = [ 1 0 0 0 2; 0 0 3 0 0 ; 0 0 0 0 0 ;0 2 0 0 0 ];
72  * >> [U, S, V] = svd(A)
73  *
74  * U =
75  *      0     1     0     0
76  *      1     0     0     0
77  *      0     0     0    -1
78  *      0     0     1     0
79  *
80  * S =
81  *      3.0000      0           0           0           0
82  *      0           2.2361      0           0           0
83  *      0           0           2.0000      0           0
84  *      0           0           0           0           0
85  *
86  * V =
87  *      0           0.4472      0           0           -0.8944
88  *      0           0           1.0000      0           0
89  *      1.0000      0           0           0           0
90  *      0           0           0           1.0000      0
91  *      0           0.8944      0           0           0.4472
92  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 4;
213     N = 5;
214     LDA = 4;            /* (out) */
215     LDU = 4;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 5;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(M * N, sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], M, N, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
259 
260     LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "the %d-th argument had illegal value"n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "DBDSDC didn't converge, updating process failed"n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %fn", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %dn", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "the %d-th argument had illegal value"n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: "DBDSDC didn't converge, updating process failed"n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completedn");
292 
293     fprintf(stderr, "Result A:n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U:n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298 
299     fprintf(stderr, "Result S:n");
300     dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end);
301 
302     fprintf(stderr, "Result VT:n");
303     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
304 
305     free(WORK);
306     free(A);
307     free(S);
308     free(U);
309     free(VT);
310 
311     return 0;

错误的结果:

peter@xx:~/$ ./test2
Reference array A:
1.000000    0.000000    0.000000    0.000000    2.000000
0.000000    0.000000    3.000000    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
0.000000    2.000000    0.000000    0.000000    0.000000
Reference array U:
0.000000    0.000000    1.000000    0.000000
0.000000    1.000000    0.000000    0.000000
0.000000    0.000000    0.000000    -1.000000
1.000000    0.000000    0.000000    0.000000
Reference array Sigma:
2.000000    0.000000    0.000000    0.000000    0.000000
0.000000    3.000000    0.000000    0.000000    0.000000
0.000000    0.000000    2.236068    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
Reference array VT:
0.000000    1.000000    0.000000    0.000000    0.000000
0.000000    0.000000    1.000000    0.000000    0.000000
0.447214    0.000000    0.000000    0.000000    0.894427
0.000000    0.000000    0.000000    1.000000    0.000000
-0.894427   0.000000    0.000000    0.000000    -0.447214
LAPACK's dgesvd_ query optimal results: LDA 4, LDU 4, LDVT 5, LWORK 300, WORK_QUERY 300.000000
Rest of params: M 4, N 5
LAPACK's dgesvd_ SVD completed
Result A:
-3.000000   -2.000000   0.000000    -1.000000   0.500000
-2.236068   0.000000    0.000000    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
0.000000    0.500000    -0.236068   0.000000    0.000000
Result U:
0.707107    0.000000    0.000000    0.707107
-0.707107   0.000000    -0.000000   0.707107
0.000000    0.000000    1.000000    0.000000
0.000000    1.000000    0.000000    0.000000
Result S:
3.872983    1.732051    0.000000    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
0.000000    0.000000    0.000000    0.000000    0.000000
Result VT:
0.182574    -0.408248   0.000000    0.000000    -0.894427
0.912871    0.408248    0.000000    0.000000    0.000000
-0.000000   -0.000000   1.000000    0.000000    0.000000
-0.000000   -0.000000   0.000000    1.000000    0.000000
0.365148    -0.816497   0.000000    0.000000    0.447214

对于一般的(非方阵)矩阵,我哪里做错了?

函数 dgesvd_ 要求矩阵按列主序(column-major)存储,而代码却是按行主序(row-major)填充数据的:

227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }

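因此,您的代码实际上计算的是下面这个矩阵的 SVD(把按行主序写入的 20 个数按列主序、LDA = 4 重新解读):

$$
\begin{bmatrix}
1 & 2 & 0 & 0 & 2 \\
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 \\
0 & 3 & 0 & 0 & 0
\end{bmatrix}
=
\begin{bmatrix}
1 & 0 & 0 & 0 \\
2 & 0 & 0 & 3 \\
0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 \\
2 & 0 & 0 & 0
\end{bmatrix}^{T}
$$

这个矩阵的奇异值确实约为 3.87 和 1.73(即 $\sqrt{15}$ 和 $\sqrt{3}$),其余为 0。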

这个错误在第一个例子中没有暴露出来,因为那个矩阵既是方阵(M = N)又是对称矩阵:对称矩阵的转置等于它自身,所以按行主序和按列主序存储得到的内容完全相同。

此外,参数 S 只需要是长度为 min(M, N) 的一维数组(如您的第一个示例所示)。由于您用 dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end); 把它当作 M×N 的行主序矩阵来打印,这些奇异值就连续出现在了第一行……

相关内容

  • 没有找到相关文章

最新更新