关于性能:Python / Numpy – 快速查找最接近某些值的数组中的索引

Python/Numpy - Quickly Find the Index in an Array Closest to Some Value

我有一个值数组t,它总是按递增顺序排列(但不总是等距排列)。我还有一个单值x。我需要在t中找到索引,这样t[索引]最接近x。对于xt.max(),函数必须返回最大索引(或-1)。

我已经编写了两个函数来完成这项工作。第一个,f1,在这个简单的计时测试中要快得多。但我喜欢第二条线只是一条线。这个计算将在一个大数组上进行,可能每秒进行多次。

有人能想出其他一些与第一个函数具有可比时间的函数,但是代码看起来更清晰吗?不如先快一点怎么样(速度最重要)?

谢谢!

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import timeit

t = np.arange(10,100000)         # Not always uniform, but in increasing order
x = np.random.uniform(10,100000) # Some value to find within t

def f1(t, x):
   ind = np.searchsorted(t, x)   # Get index to preserve order
   ind = min(len(t)-1, ind)      # In case x > max(t)
   ind = max(1, ind)             # In case x < min(t)
   if x < (t[ind-1] + t[ind]) / 2.0:   # Closer to the smaller number
      ind = ind-1
   return ind

def f2(t, x):
   return np.abs(t-x).argmin()

print t,           '
'
, x,           '
'

print f1(t, x),    '
'
, f2(t, x),    '
'

print t[f1(t, x)], '
'
, t[f2(t, x)], '
'


runs = 1000
time = timeit.Timer('f1(t, x)', 'from __main__ import f1, t, x')
print round(time.timeit(runs), 6)

time = timeit.Timer('f2(t, x)', 'from __main__ import f2, t, x')
print round(time.timeit(runs), 6)


这看起来要快得多(对我来说,python 3.2-win32,numpy 1.6.0):

1
2
3
4
5
6
from bisect import bisect_left
def f3(t, x):
    i = bisect_left(t, x)
    if t[i] - x > 0.5:
        i-=1
    return i

输出:

1
2
3
4
5
6
7
8
9
10
11
[   10    11    12 ..., 99997 99998 99999]
37854.22200356027
37844
37844
37844
37854
37854
37854
f1 0.332725
f2 1.387974
f3 0.085864


np.searchsorted是二进制搜索(每次将数组分成两半)。因此,您必须以一种方式实现它,它返回小于x的最后一个值,而不是返回零。

看看这个算法(从这里):

1
2
3
4
5
6
7
8
9
10
11
12
13
def binary_search(a, x):
    lo=0
    hi = len(a)
    while lo < hi:
        mid = (lo+hi)//2
        midval = a[mid]
        if midval < x:
            lo = mid+1
        elif midval > x:
            hi = mid
        else:
            return mid
    return lo-1 if lo > 0 else 0

刚刚替换了最后一行(是return -1)。还更改了参数。

因为循环是用Python编写的,所以可能比第一个慢…(无基准)


使用搜索排序:

1
2
3
4
t = np.arange(10,100000)         # Not always uniform, but in increasing order
x = np.random.uniform(10,100000)

print t.searchsorted(x)

编辑:

啊,是的,我知道你在F1就是这么做的。也许下面的f3比f1更容易阅读。

1
2
3
4
5
6
7
8
9
10
def f3(t, x):
    ind = t.searchsorted(x)
    if ind == len(t):
        return ind - 1 # x > max(t)
    elif ind == 0:
        return 0
    before = ind-1
    if x-t[before] < t[ind]-x:
        ind -= 1
    return ind