关于oop:python中不同参数类型的方法重载

Method overloading for different argument type in python

我正在用python编写一个预处理器,其中的一部分与ast一起工作。

有一个render()方法负责将各种语句转换为源代码。

现在,我把它改成这样(简称):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def render(self, s):
   """ Render a statement by type."""

    # code block (used in structures)
    if isinstance(s, S_Block):
        # delegate to private method that does the work
        return self._render_block(s)

    # empty statement
    if isinstance(s, S_Empty):
        return self._render_empty(s)

    # a function declaration
    if isinstance(s, S_Function):
        return self._render_function(s)

    # ...

正如您所看到的,它很冗长,容易出错,代码也很长(我有更多种类的语句)。

理想的解决方案是(在Java语法中):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
String render(S_Block s)
{
    // render block
}

String render(S_Empty s)
{
    // render empty statement
}

String render(S_Function s)
{
    // render function statement
}

// ...

当然,Python不能这样做,因为它有动态类型。当我搜索如何模仿方法重载时,所有的答案都只是说"你不想在Python中这样做"。我想这在某些情况下是正确的,但在这里,kwargs实际上根本没有用处。

在没有可怕的公里长序列的情况下,如果类型检查为ifs,我将如何在python中执行此操作,如上图所示?还有,最好是用"Python式"的方法?

注意:可以有多个"渲染器"实现,它们以不同的方式呈现语句。因此,我不能将呈现代码移到语句中,只调用s.render()。它必须在渲染器类中完成。

(我发现了一些有趣的"访客"代码,但我不确定它是否真的是我想要的东西)。


如果您使用的是python 3.4(或者愿意为python 2.6+安装后端端口),那么您可以使用functools.singledispatch进行此*:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from functools import singledispatch

class S_Block(object): pass
class S_Empty(object): pass
class S_Function(object): pass


class Test(object):
    def __init__(self):
        self.render = singledispatch(self.render)
        self.render.register(S_Block, self._render_block)
        self.render.register(S_Empty, self._render_empty)
        self.render.register(S_Function, self._render_function)

    def render(self, s):
        raise TypeError("This type isn't supported: {}".format(type(s)))

    def _render_block(self, s):
        print("render block")

    def _render_empty(self, s):
        print("render empty")

    def _render_function(self, s):
        print("render function")


if __name__ =="__main__":
    t = Test()
    b = S_Block()
    f = S_Function()
    e = S_Empty()
    t.render(b)
    t.render(f)
    t.render(e)

输出:

1
2
3
render block
render function
render empty

>基于此gist的代码。


您正在寻找的重载语法可以使用guido van rossum的多方法修饰器来实现。

这里是可以修饰类方法(原始修饰普通函数)的多方法修饰器的变体。我已经命名了变体multidispatch,以将其从原始版本中消除歧义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import functools

def multidispatch(*types):
    def register(function):
        name = function.__name__
        mm = multidispatch.registry.get(name)
        if mm is None:
            @functools.wraps(function)
            def wrapper(self, *args):
                types = tuple(arg.__class__ for arg in args)
                function = wrapper.typemap.get(types)
                if function is None:
                    raise TypeError("no match")
                return function(self, *args)
            wrapper.typemap = {}
            mm = multidispatch.registry[name] = wrapper
        if types in mm.typemap:
            raise TypeError("duplicate registration")
        mm.typemap[types] = function
        return mm
    return register
multidispatch.registry = {}

它可以这样使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Foo(object):
    @multidispatch(str)
    def render(self, s):
        print('string: {}'.format(s))
    @multidispatch(float)
    def render(self, s):
        print('float: {}'.format(s))
    @multidispatch(float, int)
    def render(self, s, t):
        print('float, int: {}, {}'.format(s, t))

foo = Foo()
foo.render('text')
# string: text
foo.render(1.234)
# float: 1.234
foo.render(1.234, 2)
# float, int: 1.234, 2

上面的演示代码显示了如何根据参数的类型重载Foo.render方法。

此代码搜索完全匹配的类型,而不是检查isinstance关系。可以修改它来处理这个问题(以查找O(n)而不是O(1)为代价),但是由于听起来您不需要这个,所以我将把代码保留为这个简单的形式。


想要这个工作吗?

1
2
3
4
5
6
7
self.map = {
            S_Block : self._render_block,
            S_Empty : self._render_empty,
            S_Function: self._render_function
}
def render(self, s):
    return self.map[type(s)](s)

将对类对象的引用作为字典中的键,并将其值作为要调用的函数对象,这将使代码更短,并且不易出错。这里唯一可能发生错误的地方就是字典的定义。或者你的一个内部功能。


使用PEP-443中定义的装饰器,使用functools.singleDispatch的替代实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from functools import singledispatch

class S_Unknown: pass
class S_Block: pass
class S_Empty: pass
class S_Function: pass
class S_SpecialBlock(S_Block): pass

@singledispatch
def render(s, **kwargs):
  print('Rendering an unknown type')

@render.register(S_Block)
def _(s, **kwargs):
  print('Rendering an S_Block')

@render.register(S_Empty)
def _(s, **kwargs):
  print('Rendering an S_Empty')

@render.register(S_Function)
def _(s, **kwargs):
  print('Rendering an S_Function')

if __name__ == '__main__':
  for t in [S_Unknown, S_Block, S_Empty, S_Function, S_SpecialBlock]:
    print(f'Passing an {t.__name__}')
    render(t())

此输出

1
2
3
4
5
6
7
8
9
10
Passing an S_Unknown
Rendering an unknown type
Passing an S_Block
Rendering an S_Block
Passing an S_Empty
Rendering an S_Empty
Passing an S_Function
Rendering an S_Function
Passing an S_SpecialBlock
Rendering an S_Block

我更喜欢这个版本,因为它与使用isinstance()的实现具有相同的行为:当您传递一个s specialBlock时,它会将它传递给接受s u块的渲染器。

可利用性

正如Dano在另一个答案中提到的,这在Python3.4+中有效,并且有一个针对Python2.6+的后端端口。

如果您有python 3.7+,那么register()属性支持使用类型注释:

1
2
3
@render.register
def _(s: S_Block, **kwargs):
  print('Rendering an S_Block')

。注释

我能看到的一个问题是,你必须把s作为一个位置论点,这意味着你不能做render(s=S_Block())

由于single_dispatch使用第一个参数的类型来确定要调用哪个版本的render(),这将导致类型错误-"render需要至少一个位置参数"(cf源代码)。

实际上,如果只有一个关键字参数,我认为应该可以使用它。如果您真的需要这样做,那么您可以做类似于这个答案的事情,它用不同的包装器创建一个定制的装饰器。这也是Python的一个很好的特性。


要在@unutbu的答案中添加一些性能度量:

1
2
3
4
5
6
7
8
9
10
@multimethod(str)
def foo(bar: str) -> int:
    return 'string: {}'.format(bar)

@multimethod(float)
def foo(bar: float) -> int:
    return 'float: {}'.format(bar)

def foo_simple(bar):
    return 'string: {}'.format(bar)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import time

string_type ="test"
iterations = 10000000

start_time1 = time.time()
for i in range(iterations):
    foo(string_type)
end_time1 = time.time() - start_time1


start_time2 = time.time()
for i in range(iterations):
    foo_simple(string_type)
end_time2 = time.time() - start_time2

print("multimethod:" + str(end_time1))
print("standard:" + str(end_time2))

返回:

1
2
> multimethod: 16.846999883651733
> standard:     4.509999990463257