使Haskell代码更快

Making Haskell Code faster

有人知道如何让这个haskell代码更有趣吗?我在做Euler项目。此代码在4.029秒内运行:

1
2
3
4
5
6
7
8
collatz :: Int -> Int64 -> Int                                                                                                                                                
collatz c 1 = c                                                                
collatz c k                                                                    
    | even k    = collatz (c+1) (k `div` 2)                                    
    | otherwise = collatz (c+1) (3*k + 1)                                      

main = do                    
    print $ maximum (map (\i -> (collatz 1 i, i)) [1..1000000])

记忆collatz函数实际上增加了运行时,所以我没有做任何记忆。类似的C代码在0.239秒内运行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int main(int argc, char *argv[])
{
    int maxlength = 0;
    int maxstart = 1;
    for (int i = 1; i <= 1000000; i++) {
        unsigned long k = i;
        int length = 1;
        while (k > 1) {
            length += 1;
            if (k % 2 == 0)
                k = k / 2;
            else
                k = 3*k + 1;
        }
        if (length > maxlength) {
            maxlength = length;
            maxstart = i;
        }
    }
    printf("%d, %d
"
, maxlength, maxstart);
    return 0;
}

haskell代码是用ghc-o3编译的,c代码是用gcc-std=c99-o3编译的。


我注意到这个问题很大程度上是一个停顿。看这里。

代码的主要问题是,默认情况下,GHC不优化整数除法。要手动修复我的代码,

1
2
3
collatz c k                                                                    
    | k .&. 1 == 0 = collatz (c+1) (k `shiftR` 1)                                    
    | otherwise    = collatz (c+1) (3*k + 1)

但是,如果在计算机上安装了LLVM,则可以使用

1
ghc -O2 -fllvm code.hs

LLVM执行必要的优化。这两种解决方案都使我的代码以大约0.420秒的速度运行,这更接近于可比较的C代码。


以下是haskell wiki的一个解决方案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import Data.Array
import Data.List
import Data.Ord (comparing)

syrs n =
    a
    where
    a = listArray (1,n) $ 0:[1 + syr n x | x <- [2..n]]
    syr n x =
        if x' <= n then a ! x' else 1 + syr n x'
        where
        x' = if even x then x `div` 2 else 3 * x + 1

main =
    print $ maximumBy (comparing snd) $ assocs $ syrs 1000000

我机器上的计算时间:

1
2
3
4
5
6
haskell|master? ? ghc -O2 prob14_memoize.hs
[1 of 1] Compiling Main             ( prob14_memoize.hs, prob14_memoize.o )
Linking prob14_memoize ...
haskell|master? ? time ./prob14_memoize
(837799,524)
./prob14_memoize  0.63s user 0.03s system 99% cpu 0.664 total

与原件相比:

1
2
3
4
5
6
haskell|master? ? ghc -O2 prob14_orig.hs
[1 of 1] Compiling Main             ( prob14_orig.hs, prob14_orig.o )
Linking prob14_orig ...
haskell|master? ? time ./prob14_orig
(525,837799)
./prob14_orig  2.77s user 0.01s system 99% cpu 2.777 total