How to do X * diag(Y) in Scala Breeze



How can I compute X * diag(Y) in Scala Breeze? For example, X could be a CSCMatrix and Y a DenseVector.

In MATLAB syntax, this would be:

X * spdiags( Y, 0, N, N )

or:

X .* repmat( Y', K, 1 )

In SciPy syntax, this would be a "broadcast multiply":

Y * X
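To make the target operation concrete: right-multiplying by diag(Y) scales column j of X by Y(j), which is why the repmat and broadcasting forms above give the same result as the explicit product. The following is a minimal plain-Scala sketch (no Breeze; matrices are assumed to be row-major `Array[Array[Double]]`, and the helper names `scaleColumns` and `matMulDiag` are hypothetical) that checks the two forms agree:

```scala
object DiagScalingSketch {
  // Hypothetical helper: computes X * diag(y) directly by scaling column j of x by y(j).
  def scaleColumns(x: Array[Array[Double]], y: Array[Double]): Array[Array[Double]] = {
    require(x.forall(_.length == y.length), "each row of X must have y.length columns")
    x.map(row => Array.tabulate(row.length)(j => row(j) * y(j)))
  }

  // Hypothetical reference: materializes diag(y) as a dense N x N matrix
  // and multiplies naively, to show the result matches column scaling.
  def matMulDiag(x: Array[Array[Double]], y: Array[Double]): Array[Array[Double]] = {
    val n = y.length
    val d = Array.tabulate(n, n)((i, j) => if (i == j) y(i) else 0.0)
    x.map(row => Array.tabulate(n)(j => (0 until n).map(k => row(k) * d(k)(j)).sum))
  }
}
```

The direct form costs O(K·N) for a K×N matrix X, while materializing diag(Y) densely costs O(K·N²); that gap is the reason to use a sparse diagonal (or plain column scaling) in the first place.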

I ended up writing my own sparse diagonal method, plus a dense/sparse multiplication.

Use them like this:

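import breeze.linalg._

val N = 100000
val K = 100
val A = DenseMatrix.rand(N,K)        // N x K dense matrix
val b = DenseVector.rand(N)          // length-N vector holding the diagonal
val c = MatrixHelper.spdiag(b)       // N x N sparse diagonal matrix, diag(b)
val d = MatrixHelper.mul( A.t, c )   // (K x N) * (N x N) = K x N, i.e. A.t * diag(b)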
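For a diagonal N×N matrix, the compressed sparse column (CSC) layout that `spdiag` has to produce is trivial: column j holds exactly one entry, at row j, so the column pointers are 0, 1, ..., N, the row indices are 0, ..., N-1, and the values are the vector itself. These correspond to the three arrays a Breeze `CSCMatrix` exposes as `colPtrs`, `rowIndices` and `data`. A dependency-free sketch, with a hypothetical `CscArrays` case class standing in for the Breeze type:

```scala
object DiagCscSketch {
  // Hypothetical stand-in for the three arrays behind a CSC matrix.
  // The stored entries of column j live at positions colPtrs(j) until colPtrs(j + 1).
  final case class CscArrays(
      rows: Int,
      cols: Int,
      colPtrs: Array[Int],
      rowIndices: Array[Int],
      data: Array[Double])

  // Builds the CSC arrays of diag(y): one stored entry per column, on the diagonal.
  def diagCsc(y: Array[Double]): CscArrays = {
    val n = y.length
    CscArrays(
      rows = n,
      cols = n,
      colPtrs = (0 to n).toArray,         // 0, 1, ..., n (length n + 1)
      rowIndices = (0 until n).toArray,   // column j's single entry is at row j
      data = y.clone()                    // the diagonal values themselves
    )
  }
}
```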

Here are the implementations of spdiag and mul:

// Copyright Hugh Perkins 2012
// You can use this under the terms of the Apache Public License 2.0
// http://www.apache.org/licenses/LICENSE-2.0
package root
import breeze.linalg._
object MatrixHelper {
   // It's only efficient to put the sparse matrix on the right-hand side, since
   // CSCMatrix is column-compressed: the nonzeros of each column are stored contiguously.
   // Computes A * B, where A is dense (m x n) and B is sparse (n x p).
   def mul( A: DenseMatrix[Double], B: CSCMatrix[Double] ) : DenseMatrix[Double] = {
      require(A.cols == B.rows, s"dimension mismatch: A.cols = ${A.cols}, B.rows = ${B.rows}")
      val resultRows = A.rows
      val resultCols = B.cols
      val result = DenseMatrix.zeros[Double](resultRows, resultCols)
      var row = 0
      while( row < resultRows ) {
         var col = 0
         while( col < resultCols ) {
            // The stored entries of column `col` occupy positions colPtrs(col) until colPtrs(col + 1).
            val start = B.colPtrs(col)
            val end = B.colPtrs(col + 1)
            var p = start
            var sum = 0.0
            while( p < end ) {
               // Only nonzeros of B contribute: A(row, k) * B(k, col) with k = rowIndices(p).
               sum += A(row, B.rowIndices(p)) * B.data(p)
               p += 1
            }
            result(row, col) = sum
            col += 1
         }
         row += 1
      }
      result
   }
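   // Builds the N x N sparse diagonal matrix diag(a) in CSC form:
   // column i holds exactly one stored entry, a(i), at row i.
   def spdiag( a: Tensor[Int,Double] ) : CSCMatrix[Double] = {
      val size = a.size
      val colPtrs = new Array[Int](size + 1)
      val rowIndices = new Array[Int](size)
      val data = new Array[Double](size)
      var i = 0
      while( i < size ) {
         colPtrs(i) = i
         rowIndices(i) = i
         data(i) = a(i)
         i += 1
      }
      colPtrs(size) = size
      // Passing the arrays to the constructor also records the number of
      // stored entries, so activeSize == size.
      new CSCMatrix[Double](data, size, size, colPtrs, rowIndices)
   }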
}
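The inner loop of `mul` only visits the stored entries of column `col` of `B`, so each output element costs one multiply-add per nonzero in that column. With `B = spdiag(b)` every column has exactly one entry, and the product collapses to scaling column j of `A` by `b(j)`. Here is a dependency-free sketch of the same kernel over plain arrays (the dense matrix is assumed row-major, and `denseTimesCsc` is a hypothetical name, not a Breeze function):

```scala
object DenseTimesCscSketch {
  // Same loop structure as MatrixHelper.mul, but over plain arrays:
  // `a` is row-major (a(i)(k)); the sparse right operand is given by its CSC arrays.
  def denseTimesCsc(
      a: Array[Array[Double]],
      bRows: Int,
      colPtrs: Array[Int],
      rowIndices: Array[Int],
      data: Array[Double]): Array[Array[Double]] = {
    val bCols = colPtrs.length - 1
    require(a.forall(_.length == bRows), "inner dimensions must agree")
    val result = Array.ofDim[Double](a.length, bCols)
    for (i <- a.indices; j <- 0 until bCols) {
      var sum = 0.0
      var p = colPtrs(j)
      // Only the stored (nonzero) entries of column j contribute.
      while (p < colPtrs(j + 1)) {
        sum += a(i)(rowIndices(p)) * data(p)
        p += 1
      }
      result(i)(j) = sum
    }
    result
  }
}
```

Because a diagonal right operand only rescales columns, you could also skip building the sparse matrix altogether and scale each column of the dense matrix in place; in recent Breeze versions something like `for (j <- 0 until X.cols) X(::, j) *= y(j)` on a `DenseMatrix` should do this. Both routes cost O(K·N) for a K×N matrix.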
