用于在Python中重载的装饰器

Decorator for overloading in Python

我知道编写关心参数类型的函数并不是皮索尼克式的,但有些情况下,由于处理方式不同,根本无法忽略类型。

在您的函数中进行一系列的isinstance检查是很难看的;是否有可用的函数修饰器来启用函数重载?像这样:

1
2
3
4
5
6
7
@overload(str)
def func(val):
    print('This is a string')

@overload(int)
def func(val):
    print('This is an int')

更新:

以下是我对David Zaslavsky答案的一些评论:

With a few modification▼显示, this will suit my purposes pretty well. One other limitation I noticed in your implementation, since you use func.__name__ as the dictionary key, you are prone to name collisions between modules, which is not always desirable. [cont'd]

[cont.] For example, if I have one module that overloads func, and another completely unrelated module that also overloads func, these overloads will collide because the function dispatch dict is global. That dict should be made local to the module, somehow. And not only that, it should also support some kind of 'inheritance'. [cont'd]

[cont.] By 'inheritance' I mean this: say I have a module first with some overloads. Then two more modules that are unrelated but each import first; both of these modules add new overloads to the already existing ones that they just imported. These two modules should be able to use the overloads in first, but the new ones that they just added should not collide with each other between modules. (This is actually pretty hard to do right, now that I think about it.)

其中一些问题可以通过稍微更改decorator语法来解决:

第一份.py

1
2
3
4
5
6
7
@overload(str, str)
def concatenate(a, b):
    return a + b

@concatenate.overload(int, int)
def concatenate(a, b):
    return str(a) + str(b)

二年级

1
2
3
4
5
from first import concatenate

@concatenate.overload(float, str)
def concatenate(a, b):
    return str(a) + b


快速回答:Pypi上有一个重载包,它比我在下面描述的更可靠地实现了这一点,尽管使用的语法略有不同。声明它只适用于python 3,但看起来只有轻微的修改(如果有的话,我还没有尝试过)才能使它适用于python 2。

长答案:在可以重载函数的语言中,函数的名称(字面上或有效地)由有关其类型签名的信息进行扩充,无论是在定义函数时还是在调用函数时。当编译器或解释器查找函数定义时,它使用声明的名称和参数类型来解析要访问的函数。因此,在Python中实现重载的逻辑方法是实现一个包装器,该包装器使用声明的名称和参数类型来解析函数。

下面是一个简单的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from collections import defaultdict

def determine_types(args, kwargs):
    return tuple([type(a) for a in args]), \
           tuple([(k, type(v)) for k,v in kwargs.iteritems()])

function_table = defaultdict(dict)
def overload(arg_types=(), kwarg_types=()):
    def wrap(func):
        named_func = function_table[func.__name__]
        named_func[arg_types, kwarg_types] = func
        def call_function_by_signature(*args, **kwargs):
            return named_func[determine_types(args, kwargs)](*args, **kwargs)
        return call_function_by_signature
    return wrap

应使用两个可选参数调用overload,一个元组表示所有位置参数的类型,一个元组表示所有关键字参数的名称类型映射。下面是一个使用示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
>>> @overload((str, int))
... def f(a, b):
...     return a * b

>>> @overload((int, int))
... def f(a, b):
...     return a + b

>>> print f('a', 2)
aa
>>> print f(4, 2)
6

>>> @overload((str,), (('foo', int), ('bar', float)))
... def g(a, foo, bar):
...     return foo*a + str(bar)

>>> @overload((str,), (('foo', float), ('bar', float)))
... def g(a, foo, bar):
...     return a + str(foo*bar)

>>> print g('a', foo=7, bar=4.4)
aaaaaaa4.4
>>> print g('b', foo=7., bar=4.4)
b30.8

这其中的缺点包括

  • 它实际上并没有检查应用decorator的函数是否与提供给decorator的参数兼容。你可以写

    1
    2
    3
    @overload((str, int))
    def h():
        return 0

    调用函数时会出错。

  • 它不能优雅地处理与传递的参数类型对应的不存在重载版本的情况(这将有助于引发更具描述性的错误)。

  • 它区分命名参数和位置参数,因此

    1
    g('a', 7, bar=4.4)

    不起作用。

  • 使用它涉及到许多嵌套的括号,如g的定义中所述。
  • 正如注释中提到的,这不处理在不同模块中具有相同名称的函数。

我想,所有这些都可以用足够的手段加以补救。特别是,通过将调度表存储为从装饰器返回的函数的属性,可以很容易地解决名称冲突问题。但正如我所说,这只是一个简单的例子,演示了如何做到这一点的基础。


这并不能直接回答你的问题,但是如果你真的想要一个类似于不同类型的重载函数的东西,并且(非常正确地)不想使用isInstance,那么我建议如下:

1
2
3
4
5
6
7
def func(int_val=None, str_val=None):
    if sum(x != None for x in (int_val, str_val)) != 1:
        #raise exception - exactly one value should be passed in
    if int_val is not None:
        print('This is an int')
    if str_val is not None:
        print('This is a string')

在使用中,目的是显而易见的,而且甚至不需要不同的选项来具有不同的类型:

1
2
func(int_val=3)
func(str_val="squirrel")