我需要创建一个Matlab mex函数,该函数将获取一个输入矩阵并返回矩阵对角线。
输入:
1 2 3
4 5 6
预期输出:
1 2 3 0 0 0
0 0 0 4 5 6
我的问题是,由于Matlab按列而不是按行读取矩阵,所以我的mex函数给出了错误的输出。
电流输出:
1 4 0 0 0 0
0 0 2 5 0 0
0 0 0 0 3 6
您将如何更改我的代码以逐行读取输入矩阵,从而获得正确的输出?
我的代码如下:
#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
mxArray *a_in, *b_out;
const mwSize *dims;
double *a, *b;
int rows, cols;
// Input array:
a_in = mxDuplicateArray(prhs[0]);
// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = (int) dims[0];
cols = (int) dims[1];
// Output array:
if(rows == cols){
b_out = plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
}
else{
b_out = plhs[0] = mxCreateDoubleMatrix(cols, rows*cols, mxREAL);
}
// Access the contents of the input and output arrays:
a = mxGetPr(a_in);
b = mxGetPr(b_out);
// Compute exdiag function of the input array
int count = 0;
for (int i = 0; i < rows; i++) {
for(int j = 0; j<cols;j++){
if(rows == cols){
b[rows*count+count/rows] = a[j + rows * i];
count++;
}
else if(rows < cols){
b[cols*count+count/rows] = a[j + cols * i];
count++;
}
else if(rows>cols){
b[cols*count+count/rows] = a[j + cols * i];
count++;
}
}
}
}
在循环中,i
是行索引,j
是列索引。您执行a[j + rows * i]
,将两个索引混合。MATLAB按列存储数据,因此需要执行a[i + rows * j]
才能正确读取输入矩阵。
对于索引输出,您希望行保持为i
,并且希望列为i * cols + j
:
b[i + rows * (i * cols + j)] = a[i + rows * j];
请注意,您不需要执行a_in = mxDuplicateArray(prhs[0])
,因为您没有写入a_in
。您可以直接访问prhs[0]
矩阵,如果需要别名,也可以执行a_in = prhs[0]
。
此外,如果数组非常大,将数组大小强制转换为int
也会引起问题。对于数组大小和索引,最好使用mwSize
和mwIndex
。
最后,您应该始终检查输入数组的类型,如果您得到的数组不是双精度的,则可能会导致读取越界错误。
这是我的代码:
#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
mwSize const* dims;
double *a, *b;
mwSize rows, cols;
if (!mxIsDouble(prhs[0])) {
mexErrMsgTxt("Input must be doubles");
}
// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = dims[0];
cols = dims[1];
// Output array:
plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
// Access the contents of the input and output arrays:
a = mxGetPr(prhs[0]);
b = mxGetPr(plhs[0]);
// Compute exdiag function of the input array
for (mwIndex i = 0; i < rows; i++) {
for (mwIndex j = 0; j < cols; j++) {
b[i + rows * (i * cols + j)] = a[i + rows * j];
}
}
}