Matrix Multiplication in Java
1.概述
在本教程中,我们将了解如何在Java中将两个矩阵相乘。
由于语言本身并不存在矩阵概念,因此我们将自己实现,并且还将与一些库一起工作,以了解它们如何处理矩阵乘法。
最后,我们将对我们探索的不同解决方案进行一些基准测试,以便确定最快的解决方案。
2.例子
让我们从建立一个示例开始,我们将在整个教程中进行引用。
首先,我们将想象一个3×2的矩阵:
现在让我们想象第二个矩阵,这次是两行乘四列:
然后,将第一个矩阵乘以第二个矩阵,将得到一个3×4的矩阵:
提醒一下,此结果是通过使用以下公式计算所得矩阵的每个像元获得的:
其中矩阵A的行数为,顺式矩阵Band的列数为矩阵A的列数,必须与矩阵B的行数匹配。
3.矩阵乘法
3.1。 自己实施
让我们从我们自己的矩阵实现开始。
我们将使其保持简单,仅使用二维双精度数组:
1 2 3 4 5 6 7 8 9 10 | double[][] firstMatrix = { new double[]{1d, 5d}, new double[]{2d, 3d}, new double[]{1d, 7d} }; double[][] secondMatrix = { new double[]{1d, 2d, 3d, 7d}, new double[]{5d, 2d, 8d, 1d} }; |
这些是我们示例的两个矩阵。 让我们创建一个期望的乘积结果:
1 2 3 4 5 | double[][] expected = { new double[]{26d, 12d, 43d, 12d}, new double[]{17d, 10d, 30d, 17d}, new double[]{36d, 16d, 59d, 14d} }; |
现在已经完成了所有设置,让我们实现乘法算法。 我们将首先创建一个空的结果数组,并遍历其单元格以在每个单元格中存储期望值:
1 2 3 4 5 6 7 8 9 10 11 | double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) { double[][] result = new double[firstMatrix.length][secondMatrix[0].length]; for (int row = 0; row < result.length; row++) { for (int col = 0; col < result[row].length; col++) { result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col); } } return result; } |
最后,让我们实现单个单元的计算。 为了实现这一点,我们将使用示例演示中前面显示的公式:
1 2 3 4 5 6 7 | double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) { double cell = 0; for (int i = 0; i < secondMatrix.length; i++) { cell += firstMatrix[row][i] * secondMatrix[i][col]; } return cell; } |
最后,让我们检查算法的结果是否符合我们的预期结果:
1 2 | double[][] actual = multiplyMatrices(firstMatrix, secondMatrix); assertThat(actual).isEqualTo(expected); |
3.2。 EJML
我们将要看的第一个库是EJML,它表示Efficient Java Matrix Library。 在撰写本教程时,它是最近更新的Java矩阵库之一。 其目的是在计算和内存使用方面尽可能地高效。
我们必须将依赖项添加到pom.xml中的库中:
1 2 3 4 5 | <dependency> <groupId>org.ejml</groupId> <artifactId>ejml-all</artifactId> <version>0.38</version> </dependency> |
我们将使用与以前几乎相同的模式:根据我们的示例创建两个矩阵,并检查它们相乘的结果是否是我们先前计算的结果。
因此,让我们使用EJML创建矩阵。 为了实现这一点,我们将使用该库提供的SimpleMatrix类。
它可以将二维double数组作为其构造函数的输入:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | SimpleMatrix firstMatrix = new SimpleMatrix( new double[][] { new double[] {1d, 5d}, new double[] {2d, 3d}, new double[] {1d ,7d} } ); SimpleMatrix secondMatrix = new SimpleMatrix( new double[][] { new double[] {1d, 2d, 3d, 7d}, new double[] {5d, 2d, 8d, 1d} } ); |
现在,让我们定义乘法的期望矩阵:
1 2 3 4 5 6 7 | SimpleMatrix expected = new SimpleMatrix( new double[][] { new double[] {26d, 12d, 43d, 12d}, new double[] {17d, 10d, 30d, 17d}, new double[] {36d, 16d, 59d, 14d} } ); |
现在我们已经完成了所有的设置,让我们看看如何将两个矩阵相乘。 TheSimpleMatrix类提供amult()方法,该方法将另一个SimpleMatrix作为参数并返回两个矩阵的乘法:
1 | SimpleMatrix actual = firstMatrix.mult(secondMatrix); |
让我们检查获得的结果是否与预期的结果相符。
AsSimpleMatrix不会覆盖equals()方法,因此我们不能依靠它来进行验证。 但是,它提供了另一种选择:theisIdentical()方法,该方法不仅采用了另一个矩阵参数,而且还采用了双倍的容错能力,以忽略由于双精度带来的微小差异:
1 | assertThat(actual).matches(m -> m.isIdentical(expected, 0d)); |
到此为止,使用EJML库进行矩阵乘法。 让我们看看其他提供的东西。
3.3。 ND4J
现在让我们尝试ND4J库。 ND4J是一个计算库,是deeplearning4j项目的一部分。 ND4J尤其提供矩阵计算功能。
首先,我们必须获取库依赖:
1 2 3 4 5 | <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native</artifactId> <version>1.0.0-beta4</version> </dependency> |
请注意,我们此处使用的是Beta版本,因为GA版本似乎存在一些错误。
为了简洁起见,我们将不重写二维双精度数组,而只关注它们在每个库中的使用方式。 因此,对于ND4J,我们必须创建一个INDArray。 为此,我们将调用theNd4j.create()工厂方法,并将其传递给表示我们矩阵的double数组:
1 | INDArray matrix = Nd4j.create(/* a two dimensions double array */); |
与上一节一样,我们将创建三个矩阵:两个将要相乘,另一个是预期的结果。
之后,我们实际上想使用INDArray.mmul()方法在前两个矩阵之间进行乘法:
1 | INDArray actual = firstMatrix.mmul(secondMatrix); |
然后,我们再次检查实际结果是否与预期结果相符。 这次我们可以依靠相等性检查:
1 | assertThat(actual).isEqualTo(expected); |
这演示了如何使用ND4J库进行矩阵计算。
3.4。 Apache Commons
现在让我们谈谈Apache Commons Math3模块,该模块为我们提供了包括矩阵处理在内的数学计算。
同样,我们必须在pom.xml中指定依赖项:
1 2 3 4 5 | <dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-math3</artifactId> <version>3.6.1</version> </dependency> |
设置完成后,我们可以使用RealMatrix接口及其itsArray2DRowRealMatrix实现来创建我们通常的矩阵。 实现类的构造函数将二维双精度数组作为其参数:
1 | RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */); |
对于矩阵乘法,RealMatrix接口提供了采用另一个RealMatrix参数的multiply()方法:
1 | RealMatrix actual = firstMatrix.multiply(secondMatrix); |
我们最终可以验证结果是否等于我们的预期:
1 | assertThat(actual).isEqualTo(expected); |
让我们看看下一个库!
3.5。 LA4J
这个名字叫LA4J,代表Java的线性代数。
让我们也为这个添加依赖项:
1 2 3 4 5 | <dependency> <groupId>org.la4j</groupId> <artifactId>la4j</artifactId> <version>0.6.0</version> </dependency> |
现在,LA4J的工作原理与其他库非常相似。 它提供带有aBasic2DMatrix实现的aMatrix接口,该实现将二维双精度数组作为输入:
1 | Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */); |
与Apache Commons Math3模块一样,乘法方法是multimultily()并采用anotherMatrix作为其参数:
1 | Matrix actual = firstMatrix.multiply(secondMatrix); |
再一次,我们可以检查结果是否符合我们的期望:
1 | assertThat(actual).isEqualTo(expected); |
现在让我们看一下我们的最后一个库:Colt。
3.6。 小马
Colt是CERN开发的图书馆。 它提供了支持高性能科学和技术计算的功能。
与以前的库一样,我们必须获得正确的依赖关系:
1 2 3 4 5 | <dependency> <groupId>colt</groupId> <artifactId>colt</artifactId> <version>1.2.0</version> </dependency> |
为了使用Colt创建矩阵,我们必须使用DoubleFactory2D类。 它带有三个工厂实例:密集,稀疏和行压缩。 每个都经过优化以创建匹配的矩阵类型。
为了我们的目的,我们将使用密集实例。 这次,调用ismake()的方法将再次使用一个二维双精度数组,从而产生一个DoubleMatrix2D对象:
1 | DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */); |
实例化矩阵后,我们将要对其进行乘法运算。 这次,矩阵对象上没有方法可以做到这一点。 我们必须创建一个Algebra类的实例,该实例具有amult()方法,并使用两个矩阵作为参数:
1 2 | Algebra algebra = new Algebra(); DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix); |
然后,我们可以将实际结果与预期结果进行比较:
1 | assertThat(actual).isEqualTo(expected); |
4.标杆管理
既然我们已经完成了矩阵乘法的各种可能性的探索,那么让我们检查一下哪个是性能最高的。
4.1。 小矩阵
让我们从小矩阵开始。 在这里,3×2和2×4矩阵。
为了实施性能测试,我们将使用JMH基准测试库
1 2 3 4 5 6 7 8 9 10 11 12 | public static void main(String[] args) throws Exception { Options opt = new OptionsBuilder() .include(MatrixMultiplicationBenchmarking.class.getSimpleName()) .mode(Mode.AverageTime) .forks(2) .warmupIterations(5) .measurementIterations(10) .timeUnit(TimeUnit.MICROSECONDS) .build(); new Runner(opt).run(); } |
这样,JMH将为每个用@Benchmark注释的方法进行两次完整运行,每个运行都有五次预热迭代(未计入平均计算)和十次测量。 至于测量,它将收集不同库的平均执行时间(以微秒为单位)。
然后,我们必须创建一个包含数组的状态对象:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | @State(Scope.Benchmark) public class MatrixProvider { private double[][] firstMatrix; private double[][] secondMatrix; public MatrixProvider() { firstMatrix = new double[][] { new double[] {1d, 5d}, new double[] {2d, 3d}, new double[] {1d ,7d} }; secondMatrix = new double[][] { new double[] {1d, 2d, 3d, 7d}, new double[] {5d, 2d, 8d, 1d} }; } } |
这样,我们可以确保数组初始化不属于基准测试。 之后,我们仍然必须使用MatrixProvider对象作为数据源来创建执行矩阵乘法的方法。 我们将在这里不再重复代码,因为我们之前看到了每个库。
最后,我们将使用main方法运行基准测试过程。 这给我们以下结果:
1 2 3 4 5 6 7 | Benchmark Mode Cnt Score Error Units MatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication avgt 20 1,008 ± 0,032 us/op MatrixMultiplicationBenchmarking.coltMatrixMultiplication avgt 20 0,219 ± 0,014 us/op MatrixMultiplicationBenchmarking.ejmlMatrixMultiplication avgt 20 0,226 ± 0,013 us/op MatrixMultiplicationBenchmarking.homemadeMatrixMultiplication avgt 20 0,389 ± 0,045 us/op MatrixMultiplicationBenchmarking.la4jMatrixMultiplication avgt 20 0,427 ± 0,016 us/op MatrixMultiplicationBenchmarking.nd4jMatrixMultiplication avgt 20 12,670 ± 2,582 us/op |
如我们所见,EJML和Colt的性能非常好,每次操作大约只有五分之一微秒,而ND4j的性能较低,每次操作只有十微秒多一点。 其他库之间有表演。
另外,值得注意的是,当将预热迭代次数从5增加到10时,所有库的性能都会提高。
4.2。 大型矩阵
现在,如果我们采用更大的矩阵(例如3000×3000)会发生什么? 为了检查发生了什么,让我们首先创建另一个状态类,提供该大小的生成矩阵:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | @State(Scope.Benchmark) public class BigMatrixProvider { private double[][] firstMatrix; private double[][] secondMatrix; public BigMatrixProvider() {} @Setup public void setup(BenchmarkParams parameters) { firstMatrix = createMatrix(); secondMatrix = createMatrix(); } private double[][] createMatrix() { Random random = new Random(); double[][] result = new double[3000][3000]; for (int row = 0; row < result.length; row++) { for (int col = 0; col < result[row].length; col++) { result[row][col] = random.nextDouble(); } } return result; } } |
如我们所见,我们将创建3000×3000的二维双精度数组,其中填充了随机实数。
现在创建基准测试类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | public class BigMatrixMultiplicationBenchmarking { public static void main(String[] args) throws Exception { Map<String, String> parameters = parseParameters(args); ChainedOptionsBuilder builder = new OptionsBuilder() .include(BigMatrixMultiplicationBenchmarking.class.getSimpleName()) .mode(Mode.AverageTime) .forks(2) .warmupIterations(10) .measurementIterations(10) .timeUnit(TimeUnit.SECONDS); new Runner(builder.build()).run(); } @Benchmark public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) { return HomemadeMatrix .multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix()); } @Benchmark public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) { SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix()); SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix()); return firstMatrix.mult(secondMatrix); } @Benchmark public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) { RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix()); RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix()); return firstMatrix.multiply(secondMatrix); } @Benchmark public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) { Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix()); Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix()); return firstMatrix.multiply(secondMatrix); } @Benchmark public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) { INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix()); INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix()); return firstMatrix.mmul(secondMatrix); } @Benchmark public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) { DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense; DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix()); DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix()); Algebra algebra = new Algebra(); return algebra.mult(firstMatrix, secondMatrix); } } |
当运行此基准测试时,我们将获得完全不同的结果:
1 2 3 4 5 6 7 | Benchmark Mode Cnt Score Error Units BigMatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication avgt 20 511.140 ± 13.535 s/op BigMatrixMultiplicationBenchmarking.coltMatrixMultiplication avgt 20 197.914 ± 2.453 s/op BigMatrixMultiplicationBenchmarking.ejmlMatrixMultiplication avgt 20 25.830 ± 0.059 s/op BigMatrixMultiplicationBenchmarking.homemadeMatrixMultiplication avgt 20 497.493 ± 2.121 s/op BigMatrixMultiplicationBenchmarking.la4jMatrixMultiplication avgt 20 35.523 ± 0.102 s/op BigMatrixMultiplicationBenchmarking.nd4jMatrixMultiplication avgt 20 0.548 ± 0.006 s/op |
我们可以看到,自制实现和Apache库现在比以前更糟,需要花费近10分钟的时间来执行两个矩阵的乘法。
马驹花了3分钟多一点,虽然更好,但仍然很长。 EJML和LA4J在将近30秒的时间内运行良好。 但是,ND4J赢得了该基准测试,其性能在CPU后端上不到一秒钟。
4.3。 分析
这表明基准测试结果确实取决于矩阵的特性,因此指出单个获胜者非常棘手。
5.结论
在本文中,我们学习了如何通过自己或使用外部库在Java中乘法矩阵。 在研究了所有解决方案之后,我们对所有解决方案进行了基准测试,结果发现,除了ND4J以外,它们在小型矩阵上的性能都非常好。 另一方面,在更大的矩阵上,ND4J处于领先地位。
像往常一样,可以在GitHub上找到本文的完整代码。