r-如何求所有反对角线的和

  • 本文关键字:对角线 何求所 r matrix
  • 更新时间 :
  • 英文 :


我有一个矩阵M:

n = 3    
x=c(0.85, 0.1, 0.05)
M <- matrix(NA, n, n); 
for(i in 1:n){
for(j in 1:n){
M[i,j] = x[i] * x[j]
}}
#       [,1]  [,2]   [,3]
# [1,] 0.7225 0.085 0.0425
# [2,] 0.0850 0.010 0.0050
# [3,] 0.0425 0.005 0.0025

我需要找到所有反对角线的和,包括M[1,1]和M[n,n]。我的尝试是

d <-matrix(c(0, 1, 2, 1, 2, 3, 2, 3, 4), n)
tapply(M, d, sum)
0      1      2      3      4 
0.7225 0.1700 0.0950 0.0100 0.0025 

结果对我来说是正确的。

问题如何定义矩阵d的条目?可以是列(M(和行(M(上的函数。

正如您在问题中提到的,row(M)col(M)是可以使用的,尽管它们的行/列以1开头,而不是以零开头,所以您需要减去2(每个1(,给出:

tapply(M, row(M) + col(M) - 2, sum)
#     0      1      2      3      4 
#0.7225 0.1700 0.0950 0.0100 0.0025

首先注意,outer可以生成矩阵d,而无需显式列出其元素。

matrix(c(0, 1, 2, 1, 2, 3, 2, 3, 4), 3)
#>      [,1] [,2] [,3]
#> [1,]    0    1    2
#> [2,]    1    2    3
#> [3,]    2    3    4
outer(0:2, 0:2, `+`)
#>      [,1] [,2] [,3]
#> [1,]    0    1    2
#> [2,]    1    2    3
#> [3,]    2    3    4

创建于2022-03-24由reprex包(v2.0.1(

并在函数中使用它。

sumAntiDiag <- function(M){
nr <- nrow(M)
nc <- ncol(M)
d <- outer(seq.int(nr), seq.int(nc), `+`)
tapply(M, d, sum)
}
n <- 3    
x <- c(0.85, 0.1, 0.05)
M <- matrix(NA, n, n); 
for(i in 1:n){
for(j in 1:n){
M[i,j] = x[i] * x[j]
}}
sumAntiDiag(M)
#>      2      3      4      5      6 
#> 0.7225 0.1700 0.0950 0.0100 0.0025

创建于2022-03-24由reprex包(v2.0.1(

你可以做:

sapply(seq(3), function(x) seq(3) + x - 2)
#>      [,1] [,2] [,3]
#> [1,]    0    1    2
#> [2,]    1    2    3
#> [3,]    2    3    4

或者更普遍地说,

anti_diagonal <- function(n) sapply(seq(n), function(x) seq(n) + x - 2)

例如:

anti_diagonal(6)
#>      [,1] [,2] [,3] [,4] [,5] [,6]
#> [1,]    0    1    2    3    4    5
#> [2,]    1    2    3    4    5    6
#> [3,]    2    3    4    5    6    7
#> [4,]    3    4    5    6    7    8
#> [5,]    4    5    6    7    8    9
#> [6,]    5    6    7    8    9   10

尝试下面的代码,使用基于R的embed定义函数f,即

f <- function(n) embed(seq(2 * n - 1) - 1, n)[, n:1]

使得

> f(3)
[,1] [,2] [,3]
[1,]    0    1    2
[2,]    1    2    3
[3,]    2    3    4
> f(4)
[,1] [,2] [,3] [,4]
[1,]    0    1    2    3
[2,]    1    2    3    4
[3,]    2    3    4    5
[4,]    3    4    5    6
> f(5)
[,1] [,2] [,3] [,4] [,5]
[1,]    0    1    2    3    4
[2,]    1    2    3    4    5
[3,]    2    3    4    5    6
[4,]    3    4    5    6    7
[5,]    4    5    6    7    8

您可以使用sequence:

function(n) matrix(sequence(rep(n, n), seq(n) - 1), nrow = n)

输出

f <- function(n) matrix(sequence(rep(n, n), seq(n) - 1), nrow = n)
f(3)
[,1] [,2] [,3]
[1,]    0    1    2
[2,]    1    2    3
[3,]    2    3    4
f(5)
[,1] [,2] [,3] [,4] [,5]
[1,]    0    1    2    3    4
[2,]    1    2    3    4    5
[3,]    2    3    4    5    6
[4,]    3    4    5    6    7
[5,]    4    5    6    7    8

使用索引而不是tapply会稍微加快速度。或Rcpp:

sumdiags <- function(mat, minor = TRUE) {
m <- ncol(mat)

if (minor) {
n <- nrow(mat)
lens <- c(pmin(1:n, m), pmin((m - 1L):1, n))
c(mat[1], diff(cumsum(mat[sequence(lens, c(1:n, seq(2L*n, by = n, length.out = m - 1L)), n - 1L)])[cumsum(lens)]))
} else {
Recall(mat[,m:1])
}
}
# compare to tapply solution
sumdiags2 <- function(mat, minor = TRUE) {
if (minor) {
as.numeric(tapply(mat, row(mat) + col(mat), sum))
} else {
Recall(mat[,ncol(mat):1])
}
}
# or Rcpp
Rcpp::cppFunction('NumericVector sumdiagsRcpp(const NumericMatrix& mat) {
const int n = mat.nrow();
const int m = mat.ncol();
NumericVector x (n + m - 1);
for(int row = 0; row < n; row++) {
for(int col = 0; col < m; col++) {
x[row + col] += mat(row, col);
}
}
return x;
}')
# OP data
x <- c(0.85, 0.1, 0.05)
m <- outer(x, x)
sumdiags(m)
#> [1] 0.7225 0.1700 0.0950 0.0100 0.0025
sumdiags2(m)
#> [1] 0.7225 0.1700 0.0950 0.0100 0.0025
sumdiagsRcpp(m)
#> [1] 0.7225 0.1700 0.0950 0.0100 0.0025
# bigger matrix for benchmarking
m <- matrix(runif(1e6), 1e3)
microbenchmark::microbenchmark(sumdiags = sumdiags(m),
sumdiags2 = sumdiags2(m),
sumdiagsRcpp = sumdiagsRcpp(m),
check = "equal")
#> Unit: milliseconds
#>         expr       min        lq      mean    median      uq        max neval
#>     sumdiags  9.985302 10.266350 13.686723 10.803401 17.5274  22.387601   100
#>    sumdiags2 55.790402 65.140051 78.763478 67.120051 70.4165 183.936801   100
#> sumdiagsRcpp  2.192201  2.378651  2.599326  2.631751  2.7050   4.038301   100

相关内容

  • 没有找到相关文章

最新更新