r-Rcpp中的这一点正确吗

  • 本文关键字:这一点 r-Rcpp r rcpp
  • 更新时间 :
  • 英文 :


我想比较每一列,并在计算后返回所有结果。我试着写代码,但结果并不合理。因为如果矩阵中有5列,结果的数量将是5*4/2=10,而不是5。我认为问题出在代码中的m。我不知道这是否正确。谢谢

library(Rcpp)
sourceCpp(code='
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h> 
double KS(arma::colvec x, arma::colvec y) {
int n = x.n_rows;
arma::colvec w = join_cols(x, y);
arma::uvec z = arma::sort_index(w);
w.fill(-1); w.elem( find(z <= n-1) ).ones();
return max(abs(cumsum(w)))/n;
}
// [[Rcpp::export]]
Rcpp::NumericVector K_S(arma::mat mt) {
int n = mt.n_cols;
int m = 1;
Rcpp::NumericVector results(n);
for (int i = 0; i < n-1; i++) {
for (int j = i+1; j < n; j++){
arma::colvec x=mt.col(i);
arma::colvec y=mt.col(j);
results[m] = KS(x, y);
m ++;
}
}
return results;
}
')
set.seed(1)
mt <- matrix(rnorm(400*5), ncol=5)
result <- K_S(t(mt))
> result
[1] 0.0000 0.1050 0.0675 0.0475 0.0650

您有几个小错误。在修复它的过程中,我刚刚填充了一个类似的n乘n矩阵的中间版本,这使得索引错误显而易见。返回arma::rowvec也有助于解决可能的越界索引错误(默认情况下是错误(,但最后(在这种情况下!!(实际上可以只增加std::vector

代码

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
double KS(arma::colvec x, arma::colvec y) {
int n = x.n_rows;
arma::colvec w = join_cols(x, y);
arma::uvec z = arma::sort_index(w);
w.fill(-1); w.elem( find(z <= n-1) ).ones();
return max(abs(cumsum(w)))/n;
}
// [[Rcpp::export]]
std::vector<double> K_S(arma::mat mt) {
int n = mt.n_cols;
std::vector<double> res;
for (int i = 0; i < n; i++) {
for (int j = i+1; j < n; j++){
arma::colvec x=mt.col(i);
arma::colvec y=mt.col(j);
res.push_back(KS(x, y));
}
}
return res;
}
/*** R
set.seed(1)
mt <- matrix(rnorm(400*5), ncol=5)
result <- K_S(mt)
result
*/

输出

> Rcpp::sourceCpp("~/git/stackoverflow/73916783/answer.cpp")
> set.seed(1)
> mt <- matrix(rnorm(400*5), ncol=5)
> result <- K_S(mt)
> result
[1] 0.1050 0.0675 0.0475 0.0650 0.0500 0.0775 0.0575 0.0500 0.0475 0.0600
> 

相关内容

最新更新