关于python:为Enum的子类重载__init __()

Overload __init__() for a subclass of Enum

我正试图重载枚举子类的__init__()方法。奇怪的是,使用普通类的模式不再适用于枚举。

下面显示了使用普通类的所需模式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Integer:
    def __init__(self, a):
       """Accepts only int"""
        assert isinstance(a, int)
        self.a = a

    def __repr__(self):
        return str(self.a)


class RobustInteger(Integer):
    def __init__(self, a):
       """Accepts int or str"""
        if isinstance(a, str):
            super().__init__(int(a))
        else:
            super().__init__(a)


print(Integer(1))
# 1
print(RobustInteger(1))
# 1
print(RobustInteger('1'))
# 1

如果与枚举一起使用,则相同的模式将中断:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from enum import Enum
from datetime import date


class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

    def __init__(self, value):
       """Accepts int or date"""
        if isinstance(value, date):
            super().__init__(date.weekday())
        else:
            super().__init__(value)


assert WeekDay(0) == WeekDay.MONDAY
assert WeekDay(date(2019, 4, 3)) == WeekDay.MONDAY
# ---------------------------------------------------------------------------
# TypeError                                 Traceback (most recent call last)
# /path/to/my/test/file.py in <module>()
#      27
#      28
# ---> 29 class WeekDay(Enum):
#      30     MONDAY = 0
#      31     TUESDAY = 1

# /path/to/my/virtualenv/lib/python3.6/enum.py in __new__(metacls, cls, bases, classdict)
#     208             enum_member._name_ = member_name
#     209             enum_member.__objclass__ = enum_class
# --> 210             enum_member.__init__(*args)
#     211             # If another member with the same value was already defined, the
#     212             # new member becomes an alias to the existing one.

# /path/to/my/test/file.py in __init__(self, value)
#      40             super().__init__(date.weekday())
#      41         else:
# ---> 42             super().__init__(value)
#      43
#      44

# TypeError: object.__init__() takes no parameters


你必须让_missing_钩子过载。WeekDay的所有实例都是在第一次定义类时创建的;WeekDay(date(...))是一个索引操作而不是创建操作,并且__new__最初查找绑定到整数0到6的预先存在的值。如果不成功,它将调用_missing_,您可以在其中将date对象转换为这样的整数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, date):
            return cls(value.weekday())
        return super()._missing_(value)

几个例子:

1
2
3
4
5
6
7
>>> WeekDay(date(2019,3,7))
<WeekDay.THURSDAY: 3>
>>> assert WeekDay(date(2019, 4, 1)) == WeekDay.MONDAY
>>> assert WeekDay(date(2019, 4, 3)) == WeekDay.MONDAY
Traceback (most recent call last):
  File"<stdin>", line 1, in <module>
AssertionError

(注意:在python 3.6之前,_missing_不可用。)

在3.6之前,您似乎可以覆盖EnumMeta.__call__进行相同的检查,但我不确定这是否会产生意外的副作用。(关于__call__的推理总是让我的头有点转。)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Silently convert an instance of datatime.date to a day-of-week
# integer for lookup.
class WeekDayMeta(EnumMeta):
    def __call__(cls, value, *args, **kwargs):
        if isinstance(value, date):
            value = value.weekday())
        return super().__call__(value, *args, **kwargs)

class WeekDay(Enum, metaclass=WeekDayMeta):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6


有一个更好的答案,但我无论如何都会发布这个,因为它可能有助于理解这个问题。

文档给出了以下提示:

EnumMeta creates them all while it is creating the Enum class itself,
and then puts a custom new() in place to ensure that no new ones
are ever instantiated by returning only the existing member instances.

所以我们必须等待重新定义__new__,直到类被创建。通过一些难看的修补,这通过了测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from enum import Enum
from datetime import date

class WeekDay(Enum):
    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6

wnew = WeekDay.__new__

def _new(cls, value):
    if isinstance(value, date):
        return wnew(cls, value.weekday()) # not date.weekday()
    else:
        return wnew(cls, value)

WeekDay.__new__ = _new

assert WeekDay(0) == WeekDay.MONDAY
assert WeekDay(date(2019, 3, 4)) == WeekDay.MONDAY # not 2019,4,3