Java中的矩阵乘法

Matrix Multiplication in Java

1.概述

在本教程中,我们将了解如何在Java中将两个矩阵相乘。

由于语言本身并不存在矩阵概念,因此我们将自己实现,并且还将与一些库一起工作,以了解它们如何处理矩阵乘法。

最后,我们将对我们探索的不同解决方案进行一些基准测试,以便确定最快的解决方案。

2.例子

让我们从建立一个示例开始,我们将在整个教程中进行引用。

首先,我们将想象一个3×2的矩阵:

 srcset=

 srcset=

现在让我们想象第二个矩阵,这次是两行乘四列:

 srcset=

 srcset=

然后,将第一个矩阵乘以第二个矩阵,将得到一个3×4的矩阵:

 srcset=

 srcset=

提醒一下,此结果是通过使用以下公式计算所得矩阵的每个像元获得的:

 width=

 width=

其中矩阵A的行数为,顺式矩阵Band的列数为矩阵A的列数,必须与矩阵B的行数匹配。

3.矩阵乘法

3.1。 自己实施

让我们从我们自己的矩阵实现开始。

我们将使其保持简单,仅使用二维双精度数组:

1
2
3
4
5
6
7
8
9
10
double[][] firstMatrix = {
  new double[]{1d, 5d},
  new double[]{2d, 3d},
  new double[]{1d, 7d}
};

double[][] secondMatrix = {
  new double[]{1d, 2d, 3d, 7d},
  new double[]{5d, 2d, 8d, 1d}
};

这些是我们示例的两个矩阵。 让我们创建一个期望的乘积结果:

1
2
3
4
5
double[][] expected = {
  new double[]{26d, 12d, 43d, 12d},
  new double[]{17d, 10d, 30d, 17d},
  new double[]{36d, 16d, 59d, 14d}
};

现在已经完成了所有设置,让我们实现乘法算法。 我们将首先创建一个空的结果数组,并遍历其单元格以在每个单元格中存储期望值:

1
2
3
4
5
6
7
8
9
10
11
double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
    double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

    for (int row = 0; row < result.length; row++) {
        for (int col = 0; col < result[row].length; col++) {
            result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
        }
    }

    return result;
}

最后,让我们实现单个单元的计算。 为了实现这一点,我们将使用示例演示中前面显示的公式:

1
2
3
4
5
6
7
double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
    double cell = 0;
    for (int i = 0; i < secondMatrix.length; i++) {
        cell += firstMatrix[row][i] * secondMatrix[i][col];
    }
    return cell;
}

最后,让我们检查算法的结果是否符合我们的预期结果:

1
2
double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

3.2。 EJML

我们将要看的第一个库是EJML,它表示Efficient Java Matrix Library。 在撰写本教程时,它是最近更新的Java矩阵库之一。 其目的是在计算和内存使用方面尽可能地高效。

我们必须将依赖项添加到pom.xml中的库中:

1
2
3
4
5
<dependency>
    <groupId>org.ejml</groupId>
    <artifactId>ejml-all</artifactId>
    <version>0.38</version>
</dependency>

我们将使用与以前几乎相同的模式:根据我们的示例创建两个矩阵,并检查它们相乘的结果是否是我们先前计算的结果。

因此,让我们使用EJML创建矩阵。 为了实现这一点,我们将使用该库提供的SimpleMatrix类。

它可以将二维double数组作为其构造函数的输入:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
SimpleMatrix firstMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 5d},
    new double[] {2d, 3d},
    new double[] {1d ,7d}
  }
);

SimpleMatrix secondMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 2d, 3d, 7d},
    new double[] {5d, 2d, 8d, 1d}
  }
);

现在,让我们定义乘法的期望矩阵:

1
2
3
4
5
6
7
SimpleMatrix expected = new SimpleMatrix(
  new double[][] {
    new double[] {26d, 12d, 43d, 12d},
    new double[] {17d, 10d, 30d, 17d},
    new double[] {36d, 16d, 59d, 14d}
  }
);

现在我们已经完成了所有的设置,让我们看看如何将两个矩阵相乘。 TheSimpleMatrix类提供amult()方法,该方法将另一个SimpleMatrix作为参数并返回两个矩阵的乘法:

1
SimpleMatrix actual = firstMatrix.mult(secondMatrix);

让我们检查获得的结果是否与预期的结果相符。

AsSimpleMatrix不会覆盖equals()方法,因此我们不能依靠它来进行验证。 但是,它提供了另一种选择:theisIdentical()方法,该方法不仅采用了另一个矩阵参数,而且还采用了双倍的容错能力,以忽略由于双精度带来的微小差异:

1
assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

到此为止,使用EJML库进行矩阵乘法。 让我们看看其他提供的东西。

3.3。 ND4J

现在让我们尝试ND4J库。 ND4J是一个计算库,是deeplearning4j项目的一部分。 ND4J尤其提供矩阵计算功能。

首先,我们必须获取库依赖:

1
2
3
4
5
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>1.0.0-beta4</version>
</dependency>

请注意,我们此处使用的是Beta版本,因为GA版本似乎存在一些错误。

为了简洁起见,我们将不重写二维双精度数组,而只关注它们在每个库中的使用方式。 因此,对于ND4J,我们必须创建一个INDArray。 为此,我们将调用theNd4j.create()工厂方法,并将其传递给表示我们矩阵的double数组:

1
INDArray matrix = Nd4j.create(/* a two dimensions double array */);

与上一节一样,我们将创建三个矩阵:两个将要相乘,另一个是预期的结果。

之后,我们实际上想使用INDArray.mmul()方法在前两个矩阵之间进行乘法:

1
INDArray actual = firstMatrix.mmul(secondMatrix);

然后,我们再次检查实际结果是否与预期结果相符。 这次我们可以依靠相等性检查:

1
assertThat(actual).isEqualTo(expected);

这演示了如何使用ND4J库进行矩阵计算。


3.4。 Apache Commons

现在让我们谈谈Apache Commons Math3模块,该模块为我们提供了包括矩阵处理在内的数学计算。

同样,我们必须在pom.xml中指定依赖项:

1
2
3
4
5
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>

设置完成后,我们可以使用RealMatrix接口及其itsArray2DRowRealMatrix实现来创建我们通常的矩阵。 实现类的构造函数将二维双精度数组作为其参数:

1
RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */);

对于矩阵乘法,RealMatrix接口提供了采用另一个RealMatrix参数的multiply()方法:

1
RealMatrix actual = firstMatrix.multiply(secondMatrix);

我们最终可以验证结果是否等于我们的预期:

1
assertThat(actual).isEqualTo(expected);

让我们看看下一个库!

3.5。 LA4J

这个名字叫LA4J,代表Java的线性代数。

让我们也为这个添加依赖项:

1
2
3
4
5
<dependency>
    <groupId>org.la4j</groupId>
    <artifactId>la4j</artifactId>
    <version>0.6.0</version>
</dependency>

现在,LA4J的工作原理与其他库非常相似。 它提供带有aBasic2DMatrix实现的aMatrix接口,该实现将二维双精度数组作为输入:

1
Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */);

与Apache Commons Math3模块一样,乘法方法是multimultily()并采用anotherMatrix作为其参数:

1
Matrix actual = firstMatrix.multiply(secondMatrix);

再一次,我们可以检查结果是否符合我们的期望:

1
assertThat(actual).isEqualTo(expected);

现在让我们看一下我们的最后一个库:Colt。

3.6。 小马

Colt是CERN开发的图书馆。 它提供了支持高性能科学和技术计算的功能。

与以前的库一样,我们必须获得正确的依赖关系:

1
2
3
4
5
<dependency>
    <groupId>colt</groupId>
    <artifactId>colt</artifactId>
    <version>1.2.0</version>
</dependency>

为了使用Colt创建矩阵,我们必须使用DoubleFactory2D类。 它带有三个工厂实例:密集,稀疏和行压缩。 每个都经过优化以创建匹配的矩阵类型。

为了我们的目的,我们将使用密集实例。 这次,调用ismake()的方法将再次使用一个二维双精度数组,从而产生一个DoubleMatrix2D对象:

1
DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */);

实例化矩阵后,我们将要对其进行乘法运算。 这次,矩阵对象上没有方法可以做到这一点。 我们必须创建一个Algebra类的实例,该实例具有amult()方法,并使用两个矩阵作为参数:

1
2
Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);

然后,我们可以将实际结果与预期结果进行比较:

1
assertThat(actual).isEqualTo(expected);

4.标杆管理

既然我们已经完成了矩阵乘法的各种可能性的探索,那么让我们检查一下哪个是性能最高的。

4.1。 小矩阵

让我们从小矩阵开始。 在这里,3×2和2×4矩阵。

为了实施性能测试,我们将使用JMH基准测试库

1
2
3
4
5
6
7
8
9
10
11
12
public static void main(String[] args) throws Exception {
    Options opt = new OptionsBuilder()
      .include(MatrixMultiplicationBenchmarking.class.getSimpleName())
      .mode(Mode.AverageTime)
      .forks(2)
      .warmupIterations(5)
      .measurementIterations(10)
      .timeUnit(TimeUnit.MICROSECONDS)
      .build();

    new Runner(opt).run();
}

这样,JMH将为每个用@Benchmark注释的方法进行两次完整运行,每个运行都有五次预热迭代(未计入平均计算)和十次测量。 至于测量,它将收集不同库的平均执行时间(以微秒为单位)。

然后,我们必须创建一个包含数组的状态对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@State(Scope.Benchmark)
public class MatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public MatrixProvider() {
        firstMatrix =
          new double[][] {
            new double[] {1d, 5d},
            new double[] {2d, 3d},
            new double[] {1d ,7d}
          };

        secondMatrix =
          new double[][] {
            new double[] {1d, 2d, 3d, 7d},
            new double[] {5d, 2d, 8d, 1d}
          };
    }
}

这样,我们可以确保数组初始化不属于基准测试。 之后,我们仍然必须使用MatrixProvider对象作为数据源来创建执行矩阵乘法的方法。 我们将在这里不再重复代码,因为我们之前看到了每个库。

最后,我们将使用main方法运行基准测试过程。 这给我们以下结果:

1
2
3
4
5
6
7
Benchmark                                                           Mode  Cnt   Score   Error  Units
MatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20   1,008 ± 0,032  us/op
MatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20   0,219 ± 0,014  us/op
MatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   0,226 ± 0,013  us/op
MatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20   0,389 ± 0,045  us/op
MatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   0,427 ± 0,016  us/op
MatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20  12,670 ± 2,582  us/op

如我们所见,EJML和Colt的性能非常好,每次操作大约只有五分之一微秒,而ND4j的性能较低,每次操作只有十微秒多一点。 其他库之间有表演。

另外,值得注意的是,当将预热迭代次数从5增加到10时,所有库的性能都会提高。

4.2。 大型矩阵

现在,如果我们采用更大的矩阵(例如3000×3000)会发生什么? 为了检查发生了什么,让我们首先创建另一个状态类,提供该大小的生成矩阵:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@State(Scope.Benchmark)
public class BigMatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public BigMatrixProvider() {}

    @Setup
    public void setup(BenchmarkParams parameters) {
        firstMatrix = createMatrix();
        secondMatrix = createMatrix();
    }

    private double[][] createMatrix() {
        Random random = new Random();

        double[][] result = new double[3000][3000];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < result[row].length; col++) {
                result[row][col] = random.nextDouble();
            }
        }
        return result;
    }
}

如我们所见,我们将创建3000×3000的二维双精度数组,其中填充了随机实数。

现在创建基准测试类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
public class BigMatrixMultiplicationBenchmarking {
    public static void main(String[] args) throws Exception {
        Map<String, String> parameters = parseParameters(args);

        ChainedOptionsBuilder builder = new OptionsBuilder()
          .include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
          .mode(Mode.AverageTime)
          .forks(2)
          .warmupIterations(10)
          .measurementIterations(10)
          .timeUnit(TimeUnit.SECONDS);

        new Runner(builder.build()).run();
    }

    @Benchmark
    public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
        return HomemadeMatrix
          .multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
    }

    @Benchmark
    public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
        SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
        SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.mult(secondMatrix);
    }

    @Benchmark
    public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
        RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
        RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
        Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
        INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());

        return firstMatrix.mmul(secondMatrix);
    }

    @Benchmark
    public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
        DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;

        DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
        DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());

        Algebra algebra = new Algebra();
        return algebra.mult(firstMatrix, secondMatrix);
    }
}

当运行此基准测试时,我们将获得完全不同的结果:

1
2
3
4
5
6
7
Benchmark                                                              Mode  Cnt    Score    Error  Units
BigMatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20  511.140 ± 13.535   s/op
BigMatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20  197.914 ±  2.453   s/op
BigMatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   25.830 ±  0.059   s/op
BigMatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20  497.493 ±  2.121   s/op
BigMatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   35.523 ±  0.102   s/op
BigMatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20    0.548 ±  0.006   s/op

我们可以看到,自制实现和Apache库现在比以前更糟,需要花费近10分钟的时间来执行两个矩阵的乘法。

马驹花了3分钟多一点,虽然更好,但仍然很长。 EJML和LA4J在将近30秒的时间内运行良好。 但是,ND4J赢得了该基准测试,其性能在CPU后端上不到一秒钟。

4.3。 分析

这表明基准测试结果确实取决于矩阵的特性,因此指出单个获胜者非常棘手。

5.结论

在本文中,我们学习了如何通过自己或使用外部库在Java中乘法矩阵。 在研究了所有解决方案之后,我们对所有解决方案进行了基准测试,结果发现,除了ND4J以外,它们在小型矩阵上的性能都非常好。 另一方面,在更大的矩阵上,ND4J处于领先地位。

像往常一样,可以在GitHub上找到本文的完整代码。