如何逐行读取Matlab mex函数的输入矩阵



我需要创建一个Matlab mex函数,该函数将获取一个输入矩阵并返回矩阵对角线。

输入:

1 2 3
4 5 6

预期输出:

1 2 3 0 0 0
0 0 0 4 5 6

我的问题是,由于Matlab按列而不是按行读取矩阵,所以我的mex函数给出了错误的输出。

电流输出:

1 4 0 0 0 0
0 0 2 5 0 0
0 0 0 0 3 6

您将如何更改我的代码以逐行读取输入矩阵,从而获得正确的输出?

我的代码如下:

#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
mxArray *a_in, *b_out;
const mwSize *dims;
double *a, *b; 
int rows, cols;
// Input array:
a_in = mxDuplicateArray(prhs[0]);
// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = (int) dims[0];
cols = (int) dims[1];
// Output array:
if(rows == cols){
b_out = plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
}
else{
b_out = plhs[0] = mxCreateDoubleMatrix(cols, rows*cols, mxREAL);
}
// Access the contents of the input and output arrays:
a = mxGetPr(a_in);
b = mxGetPr(b_out);

// Compute exdiag function of the input array
int count = 0;
for (int i = 0; i < rows; i++) {
for(int j = 0; j<cols;j++){
if(rows == cols){
b[rows*count+count/rows] = a[j + rows * i];
count++;
}
else if(rows < cols){
b[cols*count+count/rows] = a[j + cols * i];
count++;
}
else if(rows>cols){
b[cols*count+count/rows] = a[j + cols * i];
count++;
}
}
}
}

在循环中,i是行索引,j是列索引。您执行a[j + rows * i],将两个索引混合。MATLAB按列存储数据,因此需要执行a[i + rows * j]才能正确读取输入矩阵。

对于索引输出,您希望行保持为i,并且希望列为i * cols + j:

b[i + rows * (i * cols + j)] = a[i + rows * j];

请注意,您不需要执行a_in = mxDuplicateArray(prhs[0]),因为您没有写入a_in。您可以直接访问prhs[0]矩阵,如果需要别名,也可以执行a_in = prhs[0]

此外,如果数组非常大,将数组大小强制转换为int也会引起问题。对于数组大小和索引,最好使用mwSizemwIndex

最后,您应该始终检查输入数组的类型,如果您得到的数组不是双精度的,则可能会导致读取越界错误。


这是我的代码:

#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
mwSize const* dims;
double *a, *b;
mwSize rows, cols;

if (!mxIsDouble(prhs[0])) {
mexErrMsgTxt("Input must be doubles");
}

// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = dims[0];
cols = dims[1];

// Output array:
plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
// Access the contents of the input and output arrays:
a = mxGetPr(prhs[0]);
b = mxGetPr(plhs[0]);

// Compute exdiag function of the input array
for (mwIndex i = 0; i < rows; i++) {
for (mwIndex j = 0; j < cols; j++) {
b[i + rows * (i * cols + j)] = a[i + rows * j];
}
}
}

相关内容

  • 没有找到相关文章

最新更新