Multiplying two matrices in Java



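I'm currently developing a class to represent matrices; it represents any general m x n matrix. I have worked out addition and scalar multiplication, but I'm struggling with the multiplication of two matrices. The matrix data is held in a 2D array of doubles.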

The method looks a little like this:

public Matrix multiply(Matrix A) {
    ////code
}

It will return the product matrix. The multiplication is on the right: if I call A.multiply(B), it returns the matrix AB, with B on the right.

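I don't yet need to worry about whether multiplication is defined for the given matrices; I can assume I'll be given matrices of the correct dimensions.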

Does anyone know of a simple algorithm, possibly even in pseudocode, for carrying out the multiplication?

Mathematically, the product of matrices A (l x m) and B (m x n) is defined as the matrix C (l x n) consisting of the elements:

        m
c_i_j = ∑  a_i_k * b_k_j
       k=1

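So if you're not too concerned about speed, you may be happy with the straightforward implementation, which performs l*m*n multiplications (O(n^3) for square n x n matrices). It assumes c starts out filled with zeros: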

  for (int i=0; i<l; ++i)
    for (int j=0; j<n; ++j)
      for (int k=0; k<m; ++k)
        c[i][j] += a[i][k] * b[k][j];
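
For the `multiply` method from the question, the loop above could be wrapped roughly as follows. This is only a minimal sketch: the question does not show the class, so the class name `MatrixSketch`, the field `data`, the constructor, and the `get` accessor are all assumptions made for illustration.

```java
public class MatrixSketch {
    // Assumed storage: data[row][column], as a 2D array of doubles.
    private final double[][] data;

    public MatrixSketch(double[][] data) {
        this.data = data;
    }

    /** Returns the product this * other, with other on the right. */
    public MatrixSketch multiply(MatrixSketch other) {
        int l = data.length;           // rows of this
        int m = other.data.length;     // rows of other; assumed equal to columns of this
        int n = other.data[0].length;  // columns of other
        double[][] c = new double[l][n]; // Java fills new arrays with 0.0
        for (int i = 0; i < l; ++i)
            for (int j = 0; j < n; ++j)
                for (int k = 0; k < m; ++k)
                    c[i][j] += data[i][k] * other.data[k][j];
        return new MatrixSketch(c);
    }

    /** Hypothetical accessor, added only so the result can be inspected. */
    public double get(int row, int col) {
        return data[row][col];
    }

    public static void main(String[] args) {
        MatrixSketch a = new MatrixSketch(new double[][]{{1, 2}, {3, 4}});
        MatrixSketch b = new MatrixSketch(new double[][]{{5, 6}, {7, 8}});
        MatrixSketch ab = a.multiply(b);
        System.out.println(ab.get(0, 0) + " " + ab.get(0, 1)); // 19.0 22.0
        System.out.println(ab.get(1, 0) + " " + ab.get(1, 1)); // 43.0 50.0
    }
}
```

Because the result array is freshly allocated, no explicit zeroing is needed before accumulating with `+=`.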

If you want more speed, you may want to look at alternatives such as the Strassen algorithm.

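Be warned, however: especially when multiplying small matrices on modern processor architectures, speed depends heavily on the matrix data and the multiplication order being arranged to make the best use of cache lines.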

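I strongly doubt there is much chance of influencing this factor from within a VM, so I'm not sure whether it needs to be taken into consideration.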
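
To make the point about multiplication order concrete: a Java `double[][]` is an array of separate row arrays, so reading `b[k][j]` with `k` changing in the innermost loop jumps from row to row. One common variation swaps the two inner loops so that the innermost loop walks along a single row of `b` and of `c`. The following is a sketch with a hypothetical class and method name; whether it is actually faster on a given JVM and matrix size is something that would have to be measured.

```java
import java.util.Arrays;

public class LoopOrderSketch {
    /** Computes a * b with loops ordered i-k-j instead of i-j-k. */
    public static double[][] multiplyIkj(double[][] a, double[][] b) {
        int l = a.length;     // rows of a
        int m = b.length;     // rows of b; assumed equal to columns of a
        int n = b[0].length;  // columns of b
        double[][] c = new double[l][n];
        for (int i = 0; i < l; ++i) {
            double[] cRow = c[i];
            for (int k = 0; k < m; ++k) {
                double aik = a[i][k];  // reused across the whole inner loop
                double[] bRow = b[k];
                for (int j = 0; j < n; ++j) {
                    cRow[j] += aik * bRow[j]; // sequential access in both rows
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, -1}, {0, 1, 0}};
        double[][] b = {{1, 1, 0, 0}, {0, 2, 1, 1}, {1, 1, 2, 2}};
        System.out.println(Arrays.deepToString(multiplyIkj(a, b)));
    }
}
```

Each entry still receives its terms in order of increasing `k`, so the result is the same as with the i-j-k ordering.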

Java. Matrix multiplication.

Here is code that carries out the multiplication. Tested with matrices of different sizes.

public class Matrix {
  /**
   * Matrix multiplication method.
   * @param m1 Multiplicand
   * @param m2 Multiplier
   * @return Product
   */
  public static double[][] multiplyByMatrix(double[][] m1, double[][] m2) {
    int m1ColLength = m1[0].length; // m1 columns length
    int m2RowLength = m2.length;    // m2 rows length
    if (m1ColLength != m2RowLength) return null; // matrix multiplication is not possible
    int mRRowLength = m1.length;    // m result rows length
    int mRColLength = m2[0].length; // m result columns length
    double[][] mResult = new double[mRRowLength][mRColLength];
    for (int i = 0; i < mRRowLength; i++) {     // rows from m1
      for (int j = 0; j < mRColLength; j++) {   // columns from m2
        for (int k = 0; k < m1ColLength; k++) { // columns from m1
          mResult[i][j] += m1[i][k] * m2[k][j];
        }
      }
    }
    return mResult;
  }
  public static String toString(double[][] m) {
    String result = "";
    for (int i = 0; i < m.length; i++) {
      for (int j = 0; j < m[i].length; j++) {
        result += String.format("%11.2f", m[i][j]);
      }
      result += "\n";
    }
    return result;
  }
  public static void main(String[] args) {
    // #1
    double[][] multiplicand = new double[][]{
      {3, -1, 2},
      {2,  0, 1},
      {1,  2, 1}
    };
    double[][] multiplier = new double[][]{
      {2, -1, 1},
      {0, -2, 3},
      {3,  0, 1}
    };
    System.out.println("#1\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
    // #2
    multiplicand = new double[][]{
      {1,  2, 0},
      {-1, 3, 1},
      {2, -2, 1}
    };
    multiplier = new double[][]{
      {2},
      {-1},
      {1}
    };
    System.out.println("#2\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
    // #3
    multiplicand = new double[][]{
      {1, 2, -1},
      {0, 1,  0}
    };
    multiplier = new double[][]{
      {1, 1, 0, 0},
      {0, 2, 1, 1},
      {1, 1, 2, 2}
    };
    System.out.println("#3\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
  }
}

Output:

#1
      12.00      -1.00       2.00
       7.00      -2.00       3.00
       5.00      -5.00       8.00
#2
       0.00
      -4.00
       7.00
#3
       0.00       4.00       0.00       0.00
       0.00       2.00       1.00       1.00

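In this answer, I created a class named Matrix and another class named MatrixOperations, which defines the various operations that can be performed on matrices (except row operations, of course). Here I'll extract only the multiplication code from MatrixOperations. The full project can be found on my GitHub page.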

Below is the definition of the Matrix class.

package app.matrix;
import app.matrix.util.MatrixException;
public class Matrix {
private double[][] entries;
public void setEntries(double[][] entries) {
    this.entries = entries;
}
private String name;
public double[][] getEntries() {
    return entries;
}
public String getName() {
    return name;
}
public void setName(String name) {
    this.name = name;
}
public class Dimension {
    private int rows;
    private int columns;
    public int getRows() {
        return rows;
    }
    public void setRows(int rows) {
        this.rows = rows;
    }
    public int getColumns() {
        return columns;
    }
    public void setColumns(int columns) {
        this.columns = columns;
    }
    public Dimension(int rows, int columns) {
        this.setRows(rows);
        this.setColumns(columns);
    }
    @Override
    public boolean equals(Object obj) {
        if(obj instanceof Dimension){
            return (this.getColumns() == ((Dimension) obj).getColumns()) && (this.getRows() == ((Dimension) obj).getRows());
        }
        return false;
    }
}
private Dimension dimension;
public Dimension getDimension() {
    return dimension;
}
public void setDimension(Dimension dimension) {
    this.dimension = dimension;
}
public Matrix(int dimension, String name) throws MatrixException {
    if (dimension == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else this.setEntries(new double[Math.abs(dimension)][Math.abs(dimension)]);
    this.setDimension(new Dimension(dimension, dimension));
    this.setName(name);
}
public Matrix(int dimensionH, int dimensionV, String name) throws MatrixException {
    if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
    this.setDimension(new Dimension(dimensionH, dimensionV));
    this.setName(name);
}
private static final String OVERFLOW_ITEMS_MSG = "The values are too many for the matrix's specified dimensions";
private static final String ZERO_UNIT_DIMENSION = "Zero cannot be a value for a dimension";
public Matrix(int dimensionH, int dimensionV, String name, double... values) throws MatrixException {
    if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    else if (values.length > dimensionH * dimensionV) throw new MatrixException(Matrix.OVERFLOW_ITEMS_MSG);
    else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
    this.setDimension(new Dimension(dimensionH, dimensionV));
    this.setName(name);
    int iterator = 0;
    int j;
    for (int i = 0; i < dimensionH; i++) {
        j = 0;
        while (j < dimensionV) {
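            // Fewer values than cells are accepted above; leave the remaining entries at 0.0
            if (iterator < values.length) this.entries[i][j] = values[iterator];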
            j++;
            iterator++;
        }
    }
}
public Matrix(Dimension dimension) throws MatrixException {
    this(dimension.getRows(), dimension.getColumns(), null);
}
public static Matrix identityMatrix(int dim) throws MatrixException {
    if (dim == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
    double[] i = new double[dim * dim];
    int constant = dim + 1;
    for (int j = 0; j < i.length; j = j + constant) {
        i[j] = 1.0;
    }
    return new Matrix(dim, dim, null, i);
}
public String toString() {
    StringBuilder builder = new StringBuilder("Matrix \"" + (this.getName() == null ? "Null Matrix" : this.getName()) + "\": {\n");
    for (int i = 0; i < this.getDimension().getRows(); i++) {
        for (int j = 0; j < this.getDimension().getColumns(); j++) {
            if (j == 0) builder.append("\t");
            builder.append(this.entries[i][j]);
            if (j != this.getDimension().getColumns() - 1)
                builder.append(", ");
        }
        if (i != this.getDimension().getRows()) builder.append("\n");
    }
    builder.append("}");
    return builder.toString();
}
public boolean isSquare() {
    return this.getDimension().getColumns() == this.getDimension().getRows();
}

}

And here is the matrix multiplication method, from MatrixOperations:

public static Matrix multiply(Matrix matrix1, Matrix matrix2) throws MatrixException {
    if (matrix1.getDimension().getColumns() != matrix2.getDimension().getRows())
        throw new MatrixException(MATRIX_MULTIPLICATION_ERROR_MSG);
    Matrix retVal = new Matrix(matrix1.getDimension().getRows(), matrix2.getDimension().getColumns(), matrix1.getName() + " x " + matrix2.getName());

    for (int i = 0; i < matrix1.getDimension().getRows(); i++) {
        for (int j = 0; j < matrix2.getDimension().getColumns(); j++) {
            retVal.getEntries()[i][j] = sum(arrayProduct(matrix1.getEntries()[i], getColumnMatrix(matrix2, j)));
        }
    }
    return retVal;
}

And below is the code for the sum, arrayProduct, and getColumnMatrix methods:

private static double sum(double... values) {
    double sum = 0;
    for (double value : values) {
        sum += value;
    }
    return sum;
}
private static double[] arrayProduct(double[] arr1, double[] arr2) throws MatrixException {
    if (arr1.length != arr2.length) throw new MatrixException("Array lengths must be the same");
    double[] retVal = new double[arr1.length];
    for (int i = 0; i < arr1.length; i++) {
        retVal[i] = arr1[i] * arr2[i];
    }
    return retVal;
}

private static double[] getColumnMatrix(Matrix matrix, int col) {
    double[] ret = new double[matrix.getDimension().getRows()];
    for (int i = 0; i < matrix.getDimension().getRows(); i++) {
        ret[i] = matrix.getEntries()[i][col];
    }
    return ret;
}
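
The helpers above allocate two temporary arrays for every entry of the result. If that overhead matters, the row-by-column sum can be accumulated directly instead. The following sketch works on plain `double[][]` arrays so that it is self-contained; the class name `DotProductSketch` and the `dot` helper are hypothetical and not part of the project described above.

```java
public class DotProductSketch {
    /** Sums a[row][k] * b[k][col] over k without creating temporary arrays. */
    public static double dot(double[][] a, int row, double[][] b, int col) {
        double sum = 0;
        for (int k = 0; k < b.length; k++) {
            sum += a[row][k] * b[k][col];
        }
        return sum;
    }

    /** Multiplies a by b, rejecting incompatible dimensions. */
    public static double[][] multiply(double[][] a, double[][] b) {
        if (a[0].length != b.length) {
            throw new IllegalArgumentException("Columns of a must equal rows of b");
        }
        double[][] result = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                result[i][j] = dot(a, i, b, j);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 0}, {-1, 3, 1}, {2, -2, 1}};
        double[][] b = {{2}, {-1}, {1}};
        for (double[] row : multiply(a, b)) {
            System.out.println(row[0]); // 0.0, then -4.0, then 7.0
        }
    }
}
```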

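Try this code for multiplying arrays of arbitrary dimensions and printing the result. I think this is simpler and anyone can understand it.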

public class Test {
  public static void main(String[] args) {
    int[][] array1 = {
      {1, 4, -2},
      {3, 5, -6},
      {4, 5, 2}
    };
    int[][] array2 = {
      {5, 2, 8, -1},
      {3, 6, 4, 5},
      {-2, 9, 7, -3}
    };
    Test test = new Test();
    test.printArray(test.multiplication(array1, array2));
  }
  private int[][] multiplication(int[][] array1, int[][] array2) {
    int r1, r2, c1, c2;
    r1 = array1.length;
    c1 = array1[0].length;
    r2 = array2.length;
    c2 = array2[0].length;
    int[][] result;
    if (c1 != r2) {
      System.out.println("Error!");
      result = new int[0][0];
    } else {
      result = new int[r1][c2];
      for (int i = 0; i < r1; i++) { // rows of array1
        for (int j = 0; j < c2; j++) { // columns of array2
          for (int k = 0; k < c1; k++) {
            result[i][j] += array1[i][k] * array2[k][j];
          }
        }
      }
    }
    return result;
  }
  private void printArray(int[][] array) {
    for (int[] arr : array) {
      for (int element : arr) {
        System.out.print(element + " ");
      }
      System.out.println();
    }
  }
}
