关于c ++:为什么我的Strassen的矩阵乘法变慢了?

Why is my Strassen's Matrix Multiplication slow?

我在C++中写了两个矩阵乘法程序:正则MM(源)和Strassen的MM(源),它们都在大小为2 ^ ^ k×2 ^ k的方阵上(换句话说,偶数大小的正方形矩阵)。

结果很糟糕。对于1024 x 1024矩阵,正则mm取46.381 sec,而strassen的mm取1484.303 sec(25 minutes!!!!!!)

我试图使代码尽可能简单。在Web上发现的其他Strassen的mm示例与我的代码没有太大的不同。斯特拉森代码的一个问题是显而易见的——我没有临界点,它切换到常规的mm。

我的Strassen的mm代码还有什么问题????

谢谢!

直接链接到源网址:http://pastebin.com/hqhtfpq9http://pastebin.com/usrq5tuy

编辑一。首先,有很多好的建议。感谢您抽出时间和分享知识。

我实现了更改(保留了所有代码),添加了截止点。毫米的2048x2048矩阵,与切断512已经给出了良好的结果。常规mm:191.49s斯特拉森毫米:112.179秒显著改善。在使用Intel Centrino处理器的史前联想X61 TableTPC上使用Visual Studio 2012获得了结果。我会做更多的检查(以确保得到正确的结果),并发布结果。


One issue with Strassen's code is obvious - I don't have cutoff point,
that switches to regular MM.

公平地说,递归到1点是(如果不是整个)问题的主要部分。尝试在不解决这一问题的情况下猜测其他性能瓶颈几乎是没有意义的,因为它带来了巨大的性能冲击。(换句话说,你把苹果比作橙子。)

正如注释中所讨论的,缓存对齐可能会产生影响,但不会达到这个比例。此外,缓存对齐可能比Strassen算法更损害常规算法,因为后者是缓存遗忘的。

1
2
3
4
5
6
7
void strassen(int **a, int **b, int **c, int tam) {

    // trivial case: when the matrix is 1 X 1:
    if (tam == 1) {
            c[0][0] = a[0][0] * b[0][0];
            return;
    }

那太小了。虽然strassen算法的复杂性较小,但它有一个更大的big-o常量。首先,函数调用开销一直降到1个元素。

这类似于使用合并或快速排序并一直递归到一个元素。为了提高效率,需要在大小变小时停止递归,然后返回到经典算法。

在快速/合并排序中,您将返回到低开销的O(n^2)插入或选择排序。这里你可以回到正常的O(n^3)矩阵乘法。

返回经典算法的阈值应该是一个可调的阈值,根据硬件和编译器优化代码的能力可能会有所不同。

对于像strassen乘法这样的东西,其优势仅仅是O(2.8074)比经典的O(n^3),如果这个阈值非常高,不要感到惊讶。(成千上万的元素?)

在某些应用中,可以有许多算法,每一种算法的复杂度都会降低,但大O值却会增加。结果是,多个算法在不同的大小下会变得最优。

大整数乘法就是一个臭名昭著的例子:

  • 小学乘法:o(n^2)最适合<~100位数字*
  • Karatsuba乘法:0(n^1.585)比上面快大约100位*
  • toom cook 3路:o(n^1.465)比karatusuba快约3000位*
  • 浮点fft:o(>n log(n))比karatusuba/toom-3快约700位*
  • SCH?nhage–Strassen算法(SSA):o(n log(n)loglog(n))比fft快约10亿位*
  • 固定宽度数理论转换:O(n对数(n)比SSA快几十亿位?*

*注意,这些示例阈值是近似值,可以大幅度变化-通常超过10的系数。


所以,可能还有更多的问题需要解决,但第一个问题是您使用的是指向数组的指针数组。由于您使用的数组大小是2的幂,这对于连续分配元素和使用整数除法将长数组的数字折叠成行来说是一个特别大的性能冲击。

不管怎样,这是我对一个问题的第一个猜测。正如我所说,可能会有更多的答案,当我发现它们时,我会补充这个答案。

编辑:这可能只会对问题造成少量影响。这个问题很可能是Luchian Grigore提到的涉及缓存线争用问题,其权力为2。

我验证了我的关注对于幼稚的算法是有效的。如果数组是连续的,那么简单算法的时间减少了近50%。这里是使用巴斯丁的代码(使用C++ 11依赖的平方矩阵类)。