关于numpy:通过公式在python中定义一个ndarray

Defining an ndarray in python via a formula

我有一个多维数组,以C=np.zeros([20,20,20,20])开头。然后我试图通过某种公式(在本例中是C(x)=(exp(-|x|^2))将一些值赋给c。以下代码工作正常,但速度非常慢。

1
2
3
4
5
it=np.nditer(C, flags=['multi_index'], op_flags=['readwrite'])
while not it.finished:
    diff=np.linalg.norm(np.array(it.multi_index))
    it[0]=np.exp(-diff**2)
    it.iternext()

你能用更快,甚至更多的Python般的方式做到这一点吗?


这是一种方法。

步骤1获取与代码中使用np.array(it.multi_index)计算的所有索引对应的所有组合。在这一点上,我们可以利用product from itertools

步骤2以矢量化方式对所有组合执行二级范数计算。

步骤3最后以元素方式执行C(x)=(exp(-|x|^2)

1
2
3
4
5
# Get combinations using itertools.product
combs = np.array(list(product(range(N), repeat=4)))

# Perform L2 norm and elementwise exponential calculations to get final o/p
out = np.exp(-np.sqrt((combs**2).sum(1))**2).reshape(N,N,N,N)

运行时测试和验证输出-

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
In [42]: def vectorized_app(N):
    ...:     combs = np.array(list(product(range(N), repeat=4)))
    ...:     return np.exp(-np.sqrt((combs**2).sum(1))**2).reshape(N,N,N,N)
    ...:
    ...: def original_app(N):
    ...:     C=np.zeros([N,N,N,N])
    ...:     it=np.nditer(C, flags=['multi_index'], op_flags=['readwrite'])
    ...:     while not it.finished:
    ...:         diff_n=np.linalg.norm(np.array(it.multi_index))
    ...:         it[0]=np.exp(-diff_n**2)
    ...:         it.iternext()
    ...:     return C
    ...:

In [43]: N = 10

In [44]: %timeit original_app(N)
1 loops, best of 3: 288 ms per loop

In [45]: %timeit vectorized_app(N)
100 loops, best of 3: 8.63 ms per loop

In [46]: np.allclose(vectorized_app(N),original_app(N))
Out[46]: True


所以看起来你只是不想把你的操作应用到每个元素的索引上?这个怎么样:

1
x = np.exp(-np.linalg.norm(np.indices([20,20,20,20]), axis=0)**2)

指数是一个非常巧妙的函数。对于更复杂的操作,还涉及mgrid和meshgrid。在本例中,因为您有4个维度,所以它返回一个具有形状(4,20,20,20,20)的数组。

而且纯麻木有点快。)

1
2
3
4
5
In [13]: timeit posted_code()
1 loops, best of 3: 843 ms per loop

In [14]: timeit np.exp(-np.linalg.norm(np.indices([20,20,20,20]), axis=0)**2)
100 loops, best of 3: 3.76 ms per loop

结果完全一样:

1
2
In [26]: np.all(C == x)
Out[26]: True