关于c ++:无分支K-means(或其他优化)

Branchless K-means (or other optimizations)

注意:我会更欣赏如何处理和提出这些类型的解决方案而不是解决方案本身的指南。

在我的系统中,我有一个非常关键的性能函数,它在特定的上下文中显示为头号分析热点。它位于k均值迭代的中间(已经使用并行处理每个工作线程中的子范围点)。

1
2
3
4
5
6
7
8
9
10
11
12
13
ClusterPoint& pt = points[j];
pt.min_index = -1;
pt.min_dist = numeric_limits<float>::max();
for (int i=0; i < num_centroids; ++i)
{
    const ClusterCentroid& cent = centroids[i];
    const float dist = ...;
    if (dist < pt.min_dist) // <-- #1 hotspot
    {
        pt.min_dist = dist;
        pt.min_index = i;
    }
}

处理这一段代码所需的时间的任何节省都是非常重要的,所以我经常处理它。例如,将质心环放在外部,并对给定的质心并行地遍历这些点可能是值得的。这里的簇点数以百万计,而形心数以千计。该算法应用于少量迭代(通常小于10次)。它不寻求完美的收敛性/稳定性,只是一些"合理"的近似。

任何想法都是值得赞赏的,但我真正渴望发现的是,是否可以像使用SIMD版本那样将此代码无分支化。我还没有真正开发出一种能够轻松掌握如何想出无分支解决方案的心理能力:我的大脑在这方面的失败与我早期第一次接触递归时的情况非常相似,因此,关于如何编写无分支代码以及如何为其开发适当的思维方式的指南也会有所帮助。

简而言之,我正在寻找关于如何微优化此代码的任何指南、提示和建议(不一定是解决方案)。它很可能有改进算法的空间,但我的盲点一直是在微优化解决方案中(我很想知道如何更有效地应用它们,而不是过分地使用它)。它已经紧密的多线程逻辑并行块,所以我几乎被推到微观优化的角落,作为一个更快的事情尝试没有一个更智能的算法彻底。我们完全可以更改内存布局。

响应算法建议

关于寻找微观优化O(knm)算法的错误观点,我完全同意这一观点,该算法可以在算法层面上得到明显的改进。这将这个特定的问题推向一个学术和不切实际的领域。然而,如果可以允许我讲一个轶事,我来自于高层编程的原始背景——非常强调广泛的、大规模的观点、安全性,而很少强调低级实现细节。我最近把项目换成了一种非常不同的现代风味项目,我正在从我的缓存效率、GPGPU、无分支技术、SIMD、特殊用途的MEM分配器(实际上比malloc更好)等同行那里学习各种新技巧。

在这里,我正努力赶上最新的性能趋势,令人惊讶的是,我发现那些在90年代我经常喜欢的旧数据结构(通常是链接/树型结构)实际上被更幼稚、更野蛮、微优化、并行代码(在连续的操作系统上应用调整过的指令)大大超过了。美国内存块。同时,这有点令人失望,因为我觉得我们现在更适合机器的算法,并通过这种方式缩小了可能性(尤其是GPGPU)。

最有趣的是,我发现这种微优化、快速的数组处理代码比我以前使用的复杂算法和数据结构更容易维护。首先,它们更容易概括。此外,我的同事经常会对某个领域的特定减速情况提出客户投诉,只是简单地对某些SIMD(可能还有一些SIMD)进行并行处理,并以相当快的速度将其完成。算法改进通常可以提供更多,但是这些微优化应用的速度和非侵入性让我想在这方面了解更多,因为阅读关于更好算法的论文可能需要一些时间(以及需要更广泛的更改)。因此,我最近更倾向于这种微观优化的潮流,在这个特定的案例中,可能有点太多了,但是我的好奇心更多的是扩展我在任何情况下可能的解决方案的范围。

拆卸

注意:我真的非常不擅长汇编,因此我经常以一种尝试性和错误性的方式对事物进行更多的调优,对为什么在"vtune"中显示的热点可能是瓶颈进行一些有根据的猜测,然后尝试看时间是否改善,假设这些猜测在时间确实改善或完全改善的情况下具有一定的真实性提示。如果他们没有的话就错过了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
000007FEEE3FB8A1  jl          thread_partition+70h (7FEEE3FB780h)
    {
        ClusterPoint& pt = points[j];
        pt.min_index = -1;
        pt.min_dist = numeric_limits<float>::max();
        for (int i = 0; i < num_centroids; ++i)
000007FEEE3FB8A7  cmp         ecx,r10d
000007FEEE3FB8AA  jge         thread_partition+1F4h (7FEEE3FB904h)
000007FEEE3FB8AC  lea         rax,[rbx+rbx*2]
000007FEEE3FB8B0  add         rax,rax
000007FEEE3FB8B3  lea         r8,[rbp+rax*8+8]
        {
            const ClusterCentroid& cent = centroids[i];
            const float x = pt.pos[0] - cent.pos[0];
            const float y = pt.pos[1] - cent.pos[1];
000007FEEE3FB8B8  movss       xmm0,dword ptr [rdx]
            const float z = pt.pos[2] - cent.pos[2];
000007FEEE3FB8BC  movss       xmm2,dword ptr [rdx+4]
000007FEEE3FB8C1  movss       xmm1,dword ptr [rdx-4]
000007FEEE3FB8C6  subss       xmm2,dword ptr [r8]
000007FEEE3FB8CB  subss       xmm0,dword ptr [r8-4]
000007FEEE3FB8D1  subss       xmm1,dword ptr [r8-8]
            const float dist = x*x + y*y + z*z;
000007FEEE3FB8D7  mulss       xmm2,xmm2
000007FEEE3FB8DB  mulss       xmm0,xmm0
000007FEEE3FB8DF  mulss       xmm1,xmm1
000007FEEE3FB8E3  addss       xmm2,xmm0
000007FEEE3FB8E7  addss       xmm2,xmm1

            if (dist < pt.min_dist)
// VTUNE HOTSPOT
000007FEEE3FB8EB  comiss      xmm2,dword ptr [rdx-8]
000007FEEE3FB8EF  jae         thread_partition+1E9h (7FEEE3FB8F9h)
            {
                pt.min_dist = dist;
000007FEEE3FB8F1  movss       dword ptr [rdx-8],xmm2
                pt.min_index = i;
000007FEEE3FB8F6  mov         dword ptr [rdx-10h],ecx
000007FEEE3FB8F9  inc         ecx  
000007FEEE3FB8FB  add         r8,30h
000007FEEE3FB8FF  cmp         ecx,r10d
000007FEEE3FB902  jl          thread_partition+1A8h (7FEEE3FB8B8h)
    for (int j = *irange.first; j < *irange.last; ++j)
000007FEEE3FB904  inc         edi  
000007FEEE3FB906  add         rdx,20h
000007FEEE3FB90A  cmp         edi,dword ptr [rsi+4]
000007FEEE3FB90D  jl          thread_partition+31h (7FEEE3FB741h)
000007FEEE3FB913  mov         rbx,qword ptr [irange]
            }
        }
    }
}

我们被迫将目标锁定在SSE 2上——这在我们的时代有点落后,但当我们假设即使SSE 4作为最低要求也可以(用户有一些Intel机器的原型)时,用户群实际上也出现了一次失误。

用Standalon更新


太糟糕了,我们不能使用SSE4.1,但是很好,SSE2就是这样。我还没有测试过这个,只是编译它以查看是否有语法错误,并查看程序集是否有意义(虽然GCC溢出了min_index,尽管有些xmm寄存器没有使用,但不确定为什么会发生这种情况)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
int find_closest(float *x, float *y, float *z,
                 float pt_x, float pt_y, float pt_z, int n) {
    __m128i min_index = _mm_set_epi32(3, 2, 1, 0);
    __m128 xdif = _mm_sub_ps(_mm_set1_ps(pt_x), _mm_load_ps(x));
    __m128 ydif = _mm_sub_ps(_mm_set1_ps(pt_y), _mm_load_ps(y));
    __m128 zdif = _mm_sub_ps(_mm_set1_ps(pt_z), _mm_load_ps(z));
    __m128 min_dist = _mm_add_ps(_mm_add_ps(_mm_mul_ps(xdif, xdif),
                                            _mm_mul_ps(ydif, ydif)),
                                            _mm_mul_ps(zdif, zdif));
    __m128i index = min_index;
    for (int i = 4; i < n; i += 4) {
        xdif = _mm_sub_ps(_mm_set1_ps(pt_x), _mm_load_ps(x + i));
        ydif = _mm_sub_ps(_mm_set1_ps(pt_y), _mm_load_ps(y + i));
        zdif = _mm_sub_ps(_mm_set1_ps(pt_z), _mm_load_ps(z + i));
        __m128 dist = _mm_add_ps(_mm_add_ps(_mm_mul_ps(xdif, xdif),
                                            _mm_mul_ps(ydif, ydif)),
                                            _mm_mul_ps(zdif, zdif));
        index = _mm_add_epi32(index, _mm_set1_epi32(4));
        __m128i mask = _mm_castps_si128(_mm_cmplt_ps(dist, min_dist));
        min_dist = _mm_min_ps(min_dist, dist);
        min_index = _mm_or_si128(_mm_and_si128(index, mask),
                                 _mm_andnot_si128(mask, min_index));
    }
    float mdist[4];
    _mm_store_ps(mdist, min_dist);
    uint32_t mindex[4];
    _mm_store_si128((__m128i*)mindex, min_index);
    float closest = mdist[0];
    int closest_i = mindex[0];
    for (int i = 1; i < 4; i++) {
        if (mdist[i] < closest) {
            closest = mdist[i];
            closest_i = mindex[i];
        }
    }
    return closest_i;
}

和往常一样,它期望指针16对齐。另外,填充应该是无限远的点(所以它们永远不会离目标最近)。

SSE 4.1可以让您更换这个

1
2
min_index = _mm_or_si128(_mm_and_si128(index, mask),
                         _mm_andnot_si128(mask, min_index));

以此

1
min_index = _mm_blendv_epi8(min_index, index, mask);

这里有一个ASM版本,是为vsyasm开发的,经过了一些测试(似乎可以工作)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
bits 64

section .data

align 16
centroid_four:
    dd 4, 4, 4, 4
centroid_index:
    dd 0, 1, 2, 3

section .text

global find_closest

proc_frame find_closest
    ;
    ;   arguments:
    ;       ecx: number of points (multiple of 4 and at least 4)
    ;       rdx -> array of 3 pointers to floats (x, y, z) (the points)
    ;       r8 -> array of 3 floats (the reference point)
    ;
    alloc_stack 0x58
    save_xmm128 xmm6, 0
    save_xmm128 xmm7, 16
    save_xmm128 xmm8, 32
    save_xmm128 xmm9, 48
[endprolog]
    movss xmm0, [r8]
    shufps xmm0, xmm0, 0
    movss xmm1, [r8 + 4]
    shufps xmm1, xmm1, 0
    movss xmm2, [r8 + 8]
    shufps xmm2, xmm2, 0
    ; pointers to x, y, z in r8, r9, r10
    mov r8, [rdx]
    mov r9, [rdx + 8]
    mov r10, [rdx + 16]
    ; reference point is in xmm0, xmm1, xmm2 (x, y, z)
    movdqa xmm3, [rel centroid_index]   ; min_index
    movdqa xmm4, xmm3                   ; current index
    movdqa xmm9, [rel centroid_four]     ; index increment
    paddd xmm4, xmm9
    ; calculate initial min_dist, xmm5
    movaps xmm5, [r8]
    subps xmm5, xmm0
    movaps xmm7, [r9]
    subps xmm7, xmm1
    movaps xmm8, [r10]
    subps xmm8, xmm2
    mulps xmm5, xmm5
    mulps xmm7, xmm7
    mulps xmm8, xmm8
    addps xmm5, xmm7
    addps xmm5, xmm8
    add r8, 16
    add r9, 16
    add r10, 16
    sub ecx, 4
    jna _tail
_loop:
    movaps xmm6, [r8]
    subps xmm6, xmm0
    movaps xmm7, [r9]
    subps xmm7, xmm1
    movaps xmm8, [r10]
    subps xmm8, xmm2
    mulps xmm6, xmm6
    mulps xmm7, xmm7
    mulps xmm8, xmm8
    addps xmm6, xmm7
    addps xmm6, xmm8
    add r8, 16
    add r9, 16
    add r10, 16
    movaps xmm7, xmm6
    cmpps xmm6, xmm5, 1
    minps xmm5, xmm7
    movdqa xmm7, xmm6
    pand xmm6, xmm4
    pandn xmm7, xmm3
    por xmm6, xmm7
    movdqa xmm3, xmm6
    paddd xmm4, xmm9
    sub ecx, 4
    ja _loop
_tail:
    ; calculate horizontal minumum
    pshufd xmm0, xmm5, 0xB1
    minps xmm0, xmm5
    pshufd xmm1, xmm0, 0x4E
    minps xmm0, xmm1
    ; find index of the minimum
    cmpps xmm0, xmm5, 0
    movmskps eax, xmm0
    bsf eax, eax
    ; index into xmm3, sort of
    movaps [rsp + 64], xmm3
    mov eax, [rsp + 64 + rax * 4]
    movaps xmm9, [rsp + 48]
    movaps xmm8, [rsp + 32]
    movaps xmm7, [rsp + 16]
    movaps xmm6, [rsp]
    add rsp, 0x58
    ret
endproc_frame

在C++中:

1
extern"C" int find_closest(int n, float** points, float* reference_point);


可以使用无分支三元运算符,有时称为bitselect(条件?真:假。只需对两个成员使用它,默认为不做任何事情。不要担心额外的操作,它们与if语句分支相比没有什么区别。

位选择实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
inline static int bitselect(int condition, int truereturnvalue, int falsereturnvalue)
{
    return (truereturnvalue & -condition) | (falsereturnvalue & ~(-condition)); //a when TRUE and b when FALSE
}

inline static float bitselect(int condition, float truereturnvalue, float falsereturnvalue)
{
    //Reinterpret floats. Would work because it's just a bit select, no matter the actual value
    int& at = reinterpret_cast<int&>(truereturnvalue);
    int& af = reinterpret_cast<int&>(falsereturnvalue);
    int res = (at & -condition) | (af & ~(-condition)); //a when TRUE and b when FALSE
    return  reinterpret_cast<float&>(res);
}

你的循环应该是这样的:

1
2
3
4
5
6
7
8
9
10
for (int i=0; i < num_centroids; ++i)
{
  const ClusterCentroid& cent = centroids[i];
  const float dist = ...;
  bool isSmaeller = dist < pt.min_dist;

  //use same value if not smaller
  pt.min_index = bitselect(isSmaeller, i, pt.min_index);
  pt.min_dist = bitselect(isSmaeller, dist, pt.min_dist);
}


C++是一种高级语言。假设C++源代码中的控制流转化为分支指令是有缺陷的。我没有您示例中某些类型的定义,因此我做了一个简单的带有类似条件赋值的测试程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int g(int, int);

int f(const int *arr)
{
    int min = 10000, minIndex = -1;
    for ( int i = 0; i < 1000; ++i )
    {
        if ( arr[i] < min )
        {
            min = arr[i];
            minIndex = i;
        }
    }
    return g(min, minIndex);
}

注意,使用未定义的"g"仅仅是为了防止优化器删除所有内容。我用带有-o3和-s的g++4.9.2将其翻译成x86_程序集(甚至不必更改-march的默认值),结果(不太令人惊讶)是循环体不包含分支

1
2
3
4
5
6
movl    (%rdi,%rax,4), %ecx
movl    %edx, %r8d
cmpl    %edx, %ecx
cmovle  %ecx, %r8d
cmovl   %eax, %esi
addq    $1, %rax

除此之外,认为无分支必然更快的假设也可能存在缺陷,因为新距离"击败"旧距离的概率正在减少您所看到的更多元素。这不是掷硬币。"bitselect"技巧是在编译器在生成"好像"程序集时比现在更不积极时发明的。我更愿意建议您先看看编译器实际生成的程序集类型,然后再尝试重新编写代码,以便编译器能够更好地优化它,或者将结果作为手写程序集的基础。如果您想研究SIMD,我建议您使用"最小值"方法来减少数据依赖性(在我的示例中,对"最小值"的依赖性可能是一个瓶颈)。


首先,我建议您在尝试任何代码更改之前,先看看优化构建中的反汇编。理想情况下,您希望在程序集级别查看探查器数据。这可以显示各种情况,例如:

  • 编译器可能没有生成实际的分支指令。
  • 具有瓶颈的代码行可能具有比您想象的更多的相关指令,例如dist计算。
  • 除此之外,还有一个标准技巧,即当你谈论距离计算时,通常需要平方根。你应该在过程的最后对最小平方值做平方根。

    SSE可以一次处理四个值,没有任何分支,使用mm_min_ps。如果您真的需要速度,那么您希望使用SSE(或AVX)内部函数。下面是一个基本示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
      float MinimumDistance(const float *values, int count)
      {
        __m128 min = _mm_set_ps(FLT_MAX, FLT_MAX, FLT_MAX, FLT_MAX);
        int i=0;
        for (; i < count - 3; i+=4)
        {
            __m128 distances = _mm_loadu_ps(&values[i]);
            min = _mm_min_ps(min, distances);
        }
        // Combine the four separate minimums to a single value
        min = _mm_min_ps(min, _mm_shuffle_ps(min, min, _MM_SHUFFLE(2, 3, 0, 1)));
        min = _mm_min_ps(min, _mm_shuffle_ps(min, min, _MM_SHUFFLE(1, 0, 3, 2)));

        // Deal with the last 0-3 elements the slow way
        float result = FLT_MAX;
        if (count > 3) _mm_store_ss(&result, min);
        for (; i < count; i++)
        {
            result = min(values[i], result);
        }

        return result;
      }

    为了获得最佳的SSE性能,您应该确保负载发生在对齐的地址上。如有必要,可以用与上面代码中最后几个元素相同的方法处理前几个未对齐的元素。

    另一个需要注意的是内存带宽。如果在这个循环中没有使用clusterCentroid结构的几个成员,那么从内存中读取的数据将比实际需要的要多,因为内存是在缓存行大小的块中读取的,每个块为64字节。


    这可能是双向的,但我将尝试以下结构:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    std::vector<float> centDists(num_centroids); //<-- one for each thread.
    for (size_t p=0; p<num_points; ++p) {
        Point& pt = points[p];
        for (size_t c=0; c<num_centroids; ++c) {
            const float dist = ...;
            centDists[c]=dist;
        }
        pt.min_idx it= min_element(centDists.begin(),centDists.end())-centDists.begin();    
    }

    显然,您现在必须在内存上迭代两次,这可能会影响缓存命中率(您也可以将其拆分为子范围),但另一方面,每个内部循环都应该易于向量化和展开,因此您只需测量它是否值得。

    即使你坚持使用你的版本,我也会尝试使用局部变量来跟踪最小索引和距离,并将结果应用到最后。合理的做法是,每个对pt.min_dist的读或写都是通过指针有效地完成的,这取决于编译器的优化,可能会降低性能,也可能不会降低性能。

    另一件对矢量化很重要的事情是将一个结构数组(在本例中是cententroid)转换为一个数组结构(例如,每个点的坐标都有一个数组),因为这样您就不需要额外的收集指令来加载数据以用于SIMD指令。有关该主题的更多信息,请参阅EricBrumer的演讲。

    编辑:我的系统的一些数字(Haswell,Clang 3.5):我用你的基准测试了一个简短的测试,在我的系统中,上面的代码将算法的速度降低了大约10%——基本上,没有什么可以向量化的。

    然而,在将AOS应用于质心的SOA转换时,距离计算是矢量化的,这导致与将AOS应用于SOA转换的原始结构相比,总体运行时间减少了约40%。


    一种可能的微观优化:在局部变量中存储最小距离和最小索引。编译器可能需要更频繁地以您编写的方式写入内存;在某些体系结构上,这可能会对性能产生很大影响。另一个例子见我的答案。

    亚当斯建议马上做4个比较也是一个很好的比较。

    然而,你最好的加速将来自于减少你必须检查的质心的数量。理想情况下,围绕质心构建一个kd树(或类似树),然后查询该树以找到最近的点。

    如果您周围没有任何树构建代码,下面是我最喜欢的"穷人"最近点搜索:

    1
    2
    3
    4
    Sort the points by one coordinate, e.g. cent.pos[0]
    Pick a starting index for the query point (pt)
    Iterate forwards through the candidate points until you reach the end, OR when abs(pt.pos[0] - cent.pos[0]) > min_dist
    Repeat the previous step going the opposite direction.

    额外的搜索停止条件意味着你应该跳过相当多的点;你也保证不会跳过任何比你已经找到的最好的点更近的点。

    所以对于您的代码,这看起来像

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    // sort centroid by x coordinate.
    min_index = -1;
    min_dist = numeric_limits<float>::max();

    // pick the start index. This works well if the points are evenly distributed.
    float min_x = centroids[0].pos[0];
    float max_x = centroids[num_centroids-1].pos[0];
    float cur_x = pt.pos[0];
    float t = (max_x - cur_x) / (max_x - min_x);
    // TODO clamp t between 0 and 1
    int start_index = int(t * float(num_centroids))

    // Forward search
    for (int i=start_index ; i < num_centroids; ++i)
    {
        const ClusterCentroid& cent = centroids[i];
        if (fabs(cent.pos[0] - pt.pos[0]) > min_i)
            // Everything to the right of this must be further min_dist, so break.
            // This is where the savings comes from!
            break;
        const float dist = ...;
        if (dist < min_dist)
        {
            min_dist = dist;
            min_index = i;
        }
    }

    // Backwards search
    for (int i=start_index ; i >= 0; --i)
    {
        // same as above
    }
    pt.min_dist = min_dist
    pt.min_index = min_index

    (请注意,这假设您正在计算点之间的距离,但您的程序集指示它是距离的平方。相应地调整中断条件)。

    构建树或对形心进行排序的开销很小,但是这应该通过使计算在更大的循环中更快(超过点数)来抵消。