关于collatz:为什么这个简单的haskell算法这么慢?

Why is this simple haskell algorithm so slow?

扰流器警报:这与来自Project Euler的问题14有关。

以下代码运行大约需要15秒。我有一个在1s中运行的非递归Java解决方案。我想我应该能够更接近这个代码。

1
2
3
4
5
6
7
8
9
import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

我已经用+RHS -p进行了分析,注意到分配的内存很大,并且随着输入的增长而增长。对于n = 100,0001GB是分配给(!),用于n = 1,000,00013GB(!!)已分配。

然后,-sstderr再次表明,虽然分配了大量字节,但总内存使用量为1MB,生产率为95%+,因此13GB可能是Red Herring。

我可以想到一些可能性:

  • 有些事情没有它需要的那么严格。我已经发现了foldl1',但也许我需要做更多?有没有可能标记collatz严格(这有道理吗?

  • collatz不是尾调用优化。我想应该是,但不要知道如何确认。

  • 编译器没有进行一些优化,我认为它应该这样做-例如一次只需要在内存中存储两个collatz结果(max和current)

  • 有什么建议吗?

    这几乎是一个复制品,为什么这个haskell表达这么慢?虽然我会注意到,快速Java解决方案不必执行任何记忆。有没有什么方法可以在不必求助于它的情况下加快速度?

    以下是我的分析输出,供参考:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
      Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

         scratch +RTS -p -hc -RTS

      total time  =        5.12 secs   (256 ticks @ 20 ms)
      total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

    COST CENTRE                    MODULE               %time %alloc

    collatz                        Main                  99.6   99.4


                                                                                                   individual    inherited
    COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

    MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
     CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
      collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
      main                   Main                                                 214           1   0.4    0.6   100.0  100.0
       collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
     CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
     CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
     CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
     CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
     CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

    和-sstderr:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    ./scratch +RTS -sstderr
    525
      21,085,474,908 bytes allocated in the heap
          87,799,504 bytes copied during GC
               9,420 bytes maximum residency (1 sample(s))          
              12,824 bytes maximum slop              
                   1 MB total memory in use (0 MB lost due to fragmentation)  

      Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
      Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

      INIT  time    0.00s  (  0.00s elapsed)
      MUT   time   35.38s  ( 36.37s elapsed)
      GC    time    0.40s  (  0.51s elapsed)
      RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
      EXIT  time    0.00s  (  0.00s elapsed)
      Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

      Productivity  98.9% of total user, 95.9% of total elapsed

    和Java解决方案(不是我的,取自项目Euler论坛与记忆化删除):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    public class Collatz {
      public int getChainLength( int n )
      {
        long num = n;
        int count = 1;
        while( num > 1 )
        {
          num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
          count++;
        }
        return count;
      }

      public static void main(String[] args) {
        Collatz obj = new Collatz();
        long tic = System.currentTimeMillis();
        int max = 0, len = 0, index = 0;
        for( int i = 3; i < 1000000; i++ )
        {
          len = obj.getChainLength(i);
          if( len > max )
          {
            max = len;
            index = i;
          }
        }
        long toc = System.currentTimeMillis();
        System.out.println(toc-tic);
        System.out.println("Index:" + index +", length =" + max );
      }
    }


    一开始,我想你应该试着在a前面加一个感叹号。

    1
    2
    3
    4
    collatz !a 1  = a
    collatz !a x
      | even x    = collatz (a + 1) (x `div` 2)
      | otherwise = collatz (a + 1) (3 * x + 1)

    (您需要将{-# LANGUAGE BangPatterns #-}放在源文件的顶部,这样才能工作。)

    我的理由是:问题是你在第一个论点中对collatz进行了大量的讨论:它开始是1,然后变成1 + 1,然后变成(1 + 1) + 1……都没有被强迫过。这种bang模式强制collatz的第一个参数在每次调用时都被强制执行,因此它以1开始,然后变为2,依此类推,而不会生成一个大的未赋值thunk:它只是保持为整数。

    注意,bang模式只是使用seq的简写,在这种情况下,我们可以将collatz改写如下:

    1
    2
    3
    4
    5
    collatz a _ | seq a False = undefined
    collatz a 1  = a
    collatz a x
      | even x    = collatz (a + 1) (x `div` 2)
      | otherwise = collatz (a + 1) (3 * x + 1)

    这里的诀窍是强制一个在警卫,然后总是评估为错误(所以身体是无关的)。然后评估继续进行下一个案例,一个已经被评估过的案例。然而,爆炸模式更清晰。

    不幸的是,当用-O2编译时,它的运行速度并不比原来快!我们还可以尝试什么?我们可以做的一件事是假设这两个数字永远不会溢出机器大小的整数,并给出collatz这种类型的注释:

    1
    collatz :: Int -> Int -> Int

    我们将把bang模式留在这里,因为我们仍然应该避免构建thunk,即使它们不是性能问题的根源。这使我的(慢)计算机的时间降到8.5秒。

    下一步是尝试将其更接近Java解决方案。首先要认识到的是,在haskell中,div对于负整数的行为在数学上更为正确,但比在haskell中称为quot的"正常"c除法慢。用EDCOX1(11)代替EDCOX1×10,将运行时间降低到5.2秒,用EDCOX1×15(导入数据位)替换EDCOX1×14,以匹配Java解决方案使其下降到4.9秒。

    这和我现在所能得到的一样低,但我认为这是一个很好的结果,因为你的电脑比我的电脑快,所以它应该更接近Java解决方案。

    这是最后的代码(我在路上做了一些清理):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    {-# LANGUAGE BangPatterns #-}

    import Data.Bits
    import Data.List

    collatz :: Int -> Int
    collatz = collatz' 1
      where collatz' :: Int -> Int -> Int
            collatz' !a 1 = a
            collatz' !a x
              | even x    = collatz' (a + 1) (x `shiftR` 1)
              | otherwise = collatz' (a + 1) (3 * x + 1)

    main :: IO ()
    main = print . foldl1' max . map collatz $ [1..1000000]

    看看这个程序的GHC核心(使用ghc-core),我认为这可能是最好的;collatz循环使用未固定的整数,而程序的其余部分看起来正常。我能想到的唯一改进是从map collatz [1..1000000]迭代中消除拳击。

    顺便说一句,不要担心"total alloc"这个数字;它是在程序的整个生命周期中分配的总内存,即使GC回收内存,它也不会减少。多兆字节的数字很常见。


    您可能会丢失列表和bang模式,但仍然可以使用堆栈获得相同的性能。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import Data.List
    import Data.Bits

    coll :: Int -> Int
    coll 0 = 0
    coll 1 = 1
    coll 2 = 2
    coll n =
      let a = coll (n - 1)
          collatz a 1 = a
          collatz a x
            | even x    = collatz (a + 1) (x `shiftR` 1)
            | otherwise = collatz (a + 1) (3 * x + 1)
      in max a (collatz 1 n)


    main = do
      print $ coll 100000

    这方面的一个问题是,对于大型输入(如1_U 000),您必须增加堆栈的大小。

    更新:

    这里有一个尾部递归版本,它不受堆栈溢出问题的影响。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    import Data.Word
    collatz :: Word -> Word -> (Word, Word)
    collatz a x
      | x == 1    = (a,x)
      | even x    = collatz (a + 1) (x `quot` 2)
      | otherwise = collatz (a + 1) (3 * x + 1)

    coll :: Word -> Word
    coll n = collTail 0 n
      where
        collTail m 1 = m
        collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

    注意使用Word而不是Int。它在性能上有所不同。如果您愿意的话,您仍然可以使用bang模式,这将使性能增加一倍。


    我发现有一件事在这个问题上有着惊人的不同。我坚持直截了当的递归关系,而不是折页,你应该原谅这个表达,用它来计数。重写

    1
    collatz n = if even n then n `div` 2 else 3 * n + 1

    作为

    1
    2
    3
    collatz n = case n `divMod` 2 of
                (n', 0) -> n'
                _       -> 3 * n + 1

    在2.8 GHz Athlon II x4 430 CPU的系统上,我的程序运行时间缩短了1.2秒。我最初更快的版本(使用divmod后2.3秒):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    {-# LANGUAGE BangPatterns #-}

    import Data.List
    import Data.Ord

    collatzChainLen :: Int -> Int
    collatzChainLen n = collatzChainLen' n 1
        where collatzChainLen' n !l
                | n == 1    = l
                | otherwise = collatzChainLen' (collatz n) (l + 1)

    collatz:: Int -> Int
    collatz n = case n `divMod` 2 of
                     (n', 0) -> n'
                     _       -> 3 * n + 1

    pairMap :: (a -> b) -> [a] -> [(a, b)]
    pairMap f xs = [(x, f x) | x <- xs]

    main :: IO ()
    main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))

    一个更为惯用的haskell版本大约只需9.7秒(使用divmod时为8.5秒);除了

    1
    2
    collatzChainLen :: Int -> Int
    collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

    使用data.list.stream应该允许流融合,这将使这个版本在显式累积的情况下运行得更像,但是我找不到具有data.list.stream的Ubuntu libghc*包,因此我还无法验证这一点。