ID3决策树的Python实现以及可视化

算法介绍

ID3决策树是比较经典的决策树,在周志华的机器学习中,生成决策树的算法为:
在这里插入图片描述
算法的关键是如何选择最优划分属性,在ID3决策树中,用信息增益来指导决策树选择最优划分属性
首先定义信息熵为:
在这里插入图片描述
再定义信息增益为:
在这里插入图片描述
一般而言,信息增益越大,意味着使用属性a进行划分所获得的纯度提升越大,因此我们选择最大信息增益的属性作为最优划分属性。

Python实现思路

树的数据表示

既然要实现一棵树,首先要做的就是定义节点的数据结构,在C中,节点一般以结构体的形式存储,所以我们在Python中可以参考这一思路定义一个节点类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Node():
    """
    ID3决策树的节点
    parent -- 父节点
    sons -- 子节点集合,即在该节点最优划分属性下每个属性值的分支
    attrs -- 该节点下的最优划分属性
    parent_attrs_value -- 表示该节点是父节点哪一个属性的分支
    label -- 如果这个节点是叶子节点,则存放标签
    """
    def __init__(self, parent=None):
        self.parent = parent            
        self.sons = []                  
        self.attr = None                
        self.parent_attrs_value = None  
        self.label = None

但在实际操作中,使用这一方法给代码的调试增加了难度,同时不利于后面用Graphviz包实现决策树的可视化,因此本文考虑使用另一种数据结构表示树,就是Python中的字典,我们先来看看对于西瓜书中给出的一颗决策树,用字典是如何表示的:
西瓜书中的一颗决策树:
在这里插入图片描述
对应的Python字典表示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
tree = {<!-- -->'纹理':
            {<!-- -->'清晰':
                {<!-- -->'根蒂':
                    {<!-- -->'蜷缩':
                        {<!-- -->'label':'是'},
                    '稍蜷':
                        {<!-- -->'色泽':
                            {<!-- -->'青绿':
                                {<!-- -->'label':'是'},
                            '乌黑':
                                {<!-- -->'触感':
                                    {<!-- -->'硬滑':
                                        {<!-- -->'label':'是'},
                                    '软粘':
                                        {<!-- -->'label':'否'}}},
                            '浅白':
                                {<!-- -->'label':'是'}}},
                    '硬挺':
                        {<!-- -->'label':'否'}}},
            '稍糊':
                {<!-- -->'触感':
                    {<!-- -->'硬滑':
                        {<!-- -->'label':'否'},
                    '软粘':
                        {<!-- -->'label':'是'}}},
            '模糊':
                {<!-- -->'label':'否'}}}

如何可视化决策树

在本文中,使用Graphviz包进行决策树的可视化,这里是官网和文档
只需使用几条简单的代码便可将决策树的节点绘制出来:

1
2
3
4
g = graphviz.Digraph(name=,filename=, format='png')
g.node(name=, label=, fontname="Microsoft YaHei", shape=)
g.edge(tail_name, head_name, label=, fontname="Microsoft YaHei")
g.view()

要注意,如果决策树的信息是中文的,要在fontname参数中指定中文字体,不然会出现乱码

Python代码

DecesionTree.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import numpy as np
import scipy.io as sio
from collections import Counter
from graphviz import Digraph

class DecisionTree():
    """
    一个构建ID3决策树的类
    attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值,最后一个属性为标签属性
    X -- 训练数据
    y -- 标签
    attr_idx -- 属性列索引
    tree -- 生成的决策树,用字典形式存放
    node_name -- 用于对决策树的可视化,在graphviz中对节点的命名
    """
    def __init__(self):
        self.attrs = None
        self.X = None
        self.y = None
        self.attr_idx = None
        self.tree = {<!-- -->}
        self.node_name = "0"


    def get_attrs(self, data):
        """
        对数据集进行处理,得到属性与对应的属性取值
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features), dtype='<U?', 其中,第一行为属性,最后一列为标签
        returns:
        attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值
        """
        attrs = {<!-- -->}
        for i in range(data.shape[1]):
            attrs_values = sorted(set(data[1:, i]))
            attrs[data[0][i]] = attrs_values

        self.attrs = attrs
        return attrs


    def generate_tree(self, data):
        """
        生成决策树
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features+label), dtype='<U?', 其中,第一行为属性,最后一列为标签
        """
        self.X = data[1:, :-1]
        self.y = data[1:, -1]

        # 先创建一个不含label属性的纯变量属性字典
        pure_attrs = self.attrs.copy()
        del(pure_attrs['label'])
        # 构造一个只含属性名的列表
        attr_names = [attr_name for attr_name in pure_attrs.keys()]
        # 将属性名编号,方便查找其在数据中对应的列
        attr_idx = {<!-- -->}
        for num, attr in enumerate(attr_names):
            attr_idx[attr] = num
        self.attr_idx = attr_idx

        # 生成根节点
        self.tree['root_node'] = {<!-- -->}
        self._generate_tree(self.X, self.y, self.tree['root_node'], pure_attrs, attr_idx)
        self.tree = self.tree['root_node']
       

    def _generate_tree(self, X, y, node, attrs, attr_idx):
        """
        递归生成决策树
        args:
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        parent_node -- 父节点,此次递归函数是父节点的某一个属性值的递归
        attrs -- 属性字典, 即从父节点分支到现在的节点时,还没有被划分的属性
        attr_idx -- 属性在数据中列索引
        """

        #--------- 如果训练集中样本全属于同一类别 ---------#
        if len(set(y.tolist())) == 1:
            node['label'] = y[0]
            return

        #-------- 如果属性集为空集或者训练集中样本在属性集上取值相同 ---------#
        # 判断训练集样本在属性集中取值是否相同
        same = True
        for i in range(X.shape[1]):
            if len(set(X[:, i].tolist())) > 1:
                same = False
       
        if not attrs or same:
            y_counter = Counter(y)
            most_y = y_counter.most_common()[0][0]
            node['label'] = most_y
            return

        #--------- 选择最优属性生成分支 ---------#
        # 选出最优划分属性
        optimal_attr = self.choose_optimal_attr(X, y, attrs, attr_idx)
        node[optimal_attr] = {<!-- -->}
        node = node[optimal_attr]
        # 对于最优划分属性下每个属性值
        for attr_value in attrs[optimal_attr]:
            # 生成分支
            node[attr_value] = {<!-- -->}
            # 令Dv表示X中在optimal_attr上取值为attr_value的样本子集
            Dv = X.copy()
            attr_value_idx = Dv[:, attr_idx[optimal_attr]] == attr_value
            Dv = Dv[attr_value_idx, :]
            y_Dv = y[attr_value_idx]
            Dv = np.delete(Dv, attr_idx[optimal_attr], 1)
            # 如果Dv为空
            if Dv.size == 0:
                # 将分支节点标记为叶节点,其类别标记为X中样本最多的类,即统计y
                y_counter = Counter(y)
                most_y = y_counter.most_common()[0][0]
                node[attr_value]['label'] = most_y
            else:
                # 更新属性字典
                new_attrs = attrs.copy()
                del(new_attrs[optimal_attr])
                # 更新属性列索引
                new_attr_names = [new_attr_name for new_attr_name in new_attrs.keys()]
                new_attr_idx = {<!-- -->}
                for num, attr in enumerate(new_attr_names):
                    new_attr_idx[attr] = num
                self._generate_tree(Dv, y_Dv, node[attr_value], new_attrs, new_attr_idx)


    def compute_Ent(self, y):
        """
        计算给出属性名列表所对应的所有样本的信息熵
        args:
        y -- 标签数组, shape=(samples, )
        return:
        Ent -- 样本的信息熵
        """
        Ent = 0
        m = np.size(y)
        for label in self.attrs['label']:
            pk = np.sum(y == label)
            pk = pk / m
            log2pk = np.log2(pk + 1e-8) # 防止算得0,导致返回nan
            Ent -= pk * log2pk
        return Ent


    def choose_optimal_attr(self, X, y, attrs, attr_idx):
        """
        选择最优划分属性 划分标准:属性的信息增益
        args:
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        attrs -- 属性字典
        attr_idx -- 属性在数据中列索引
        returns:
        max_gain_attr -- 最大的信息增益对应的属性
        """
        # 计算当前所含属性对应所有样本的信息熵
        Ent = self.compute_Ent(y)
        m = np.size(y)
        # 记录当前最大的信息增益以及对应的属性
        max_gain = 0
        max_gain_attr = None
       
        # 计算每一个属性的信息增益
        for attr, idx in attr_idx.items():
            x = X[:, idx]
            gain = Ent
            # 计算一个属性中每个属性值的信息熵
            for attr_value in attrs[attr]:
                _y = y[x==attr_value]
                if _y.size != 0:
                    ent = self.compute_Ent(_y)
                else:
                    ent = 0
                gain -= np.size(_y) / m * ent
            if gain > max_gain:
                max_gain = gain
                max_gain_attr = attr
               
        return max_gain_attr


    def predict(self, predict_x):
        """
        预测样本结果
        args:
        predict_x -- 预测样本数据矩阵 shape=(samples, features)
        returns:
        predict_y -- 样本的预测结果 shape=(samples, )
        """
        s = predict_x.shape[0]
        predict_y = []
        for i in range(s):
            node = self.tree
            while(1):
                if 'label' in node.keys():
                    predict_y.append(node['label'])
                    break
                elif list(node.keys())[0] in self.attrs.keys():
                    attr = list(node.keys())[0]
                    idx = self.attr_idx[attr]
                    node = node[attr]
                else:
                    node = node[predict_x[i, idx]]
        return predict_y


    def tree_traversal(self, g, parent_node, parent_node_name, parent_attr, parent_attr_value):
        """
        对树进行遍历,生成可视化的节点
        g -- 要绘制的有向图
        parent_node -- 父节点
        parent_node_name -- 父节点在有向图中的代号
        parent_attr -- 父节点的属性
        parent_attr_value -- 父节点到该节点的属性值
        """
        if (parent_attr and parent_attr_value) is None:
            if 'label' in parent_node.keys():
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                return
            else:
                attr = list(parent_node.keys())[0]
                node = parent_node[attr]
                parent_node_name = "0"
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)
        else:
            if 'label' in parent_node.keys():
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
            else:
                attr = list(parent_node.keys())[0]
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=attr, fontname="Microsoft YaHei", shape='box')
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
                node = parent_node[attr]
                parent_node_name = self.node_name
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)


    def tree_visualize(self, file_name=None):
        """
        将决策树可视化
        args:
        file_name -- 若给出该参数,则将决策树保存为file_name的图片
        """
        if file_name:
            g = Digraph("Decision Tree", filename=file_name, format='png')
        else:
            g = Digraph("Decision Tree")
        self.tree_traversal(g, self.tree, None, None, None)
        g.view()


if __name__ == "__main__":
    pass

主函数,以西瓜树的西瓜数据集为例生成决策树,原数据集是Matlab的cell数组,并以mat文件存放,因此需要预处理一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import scipy.io as sio
from DecisionTree import DecisionTree

def preprocess():
    raw_data = sio.loadmat('watermelon.mat')
    raw_data = raw_data['watermelon']
    data = np.zeros(raw_data.shape, dtype='<U20')

    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            data[i, j] = raw_data[i, j][0]

    data[0, -1] = 'label'
    return data

def main_1():
    """
    完整决策树
    """
    data = preprocess()
    DTree = DecisionTree()
    attrs = DTree.get_attrs(data)
    DTree.generate_tree(data)
    DTree.tree_visualize('watermelob_tree')

def main_2():
    """
    留出两个样本作为测试集
    """
    data = preprocess()
    train_idx = np.delete(np.arange(0, 18), [8, 17])
    test_idx = [8, 17]
    train_data = data[train_idx, :]
    test_data = data[test_idx, :]
    test_X = test_data[:, :-1]
    test_y = test_data[:, -1]

    DTree = DecisionTree()
    DTree.get_attrs(train_data)
    DTree.generate_tree(train_data)
    predict_y = DTree.predict(test_X)
    print(predict_y)
    DTree.tree_visualize('watermelon_tree_2')

main_1()

最终生成的决策树图片为:
在这里插入图片描述
到这里我们就成功地用Python实现了ID3决策树!