Simple sum-function in Python with numba doesn't compute
我正在尝试学习python和numba,我不明白为什么下面的代码不能在i python/jupyter中计算:
1 2 3 4 5 6 7 8 9 | from numba import * sample_array = np.arange(10000.0) @jit('float64(float64, float64)') def sum(x, y): return x + y sum(sample_array, sample_array) |
TypeError Traceback (most recent call last)
in ()
----> 1 sum(sample_array, sample_array)C:\Users***\AppData\Local\Continuum\Anaconda\lib\site-packages
umba\dispatcher.pyc in _explain_matching_error(self, *args, **kws)
201 msg = ("No matching definition for argument type(s) %s"
202 % ', '.join(map(str, args)))
--> 203 raise TypeError(msg)
204
205 def repr(self):TypeError: No matching definition for argument type(s) array(float64, 1d, C), array(float64, 1d, C)
您正在传入数组,但JIT签名需要标量浮点数。请尝试以下操作:
1 2 3 | @jit('float64[:](float64[:], float64[:])') def sum(x, y): return x + y |
我的建议是,看看您是否可以避免不指定类型,而只使用裸露的
1 2 3 4 5 6 7 8 9 | @jit(nopython=True) def sum(x, y): return x + y In [13]: sum(1,2) Out[13]: 3 In [14]: sum(np.arange(5),np.arange(5)) Out[14]: array([0, 2, 4, 6, 8]) |
我的经验是,添加这些类型很少会给性能带来任何好处。