关于matlab:编码分类树…如何存储?

Coding a classification tree… how to store?

我在这里找一个基本的伪代码大纲。

我的目标是从头开始编写一个分类树(我在学习机器学习,希望获得直觉)。但是我的培训数据是巨大的:40000个例子和1000个特性。考虑到所需拆分数量的上限是240000,我将失去如何跟踪所有这些分区数据集的能力。

假设我从完整的数据集开始并进行一次拆分。然后,我可以将位于拆分一侧的20000个示例保存到一个数据集中,然后重新运行拆分算法以找到该数据集的贪婪拆分。然后说我继续这样做,沿着树的最左边的树枝分了几十次。

当我对所有最左边的部分感到满意时,那又是什么呢?如何存储最多240000个单独的子集?当我对一个测试示例进行分类时,如何跟踪我所采取的所有拆分?对我来说,代码的组织是没有意义的。


感谢@natan提供详细答案。

但是,如果我理解正确,您关心的主要问题是如何在每个培训样本通过决策树传播时有效地跟踪它。

这很容易做到。

您只需要一个大小为N=40000的向量,每个训练样本都有一个条目。这个向量将告诉您每个样本在树中的位置。我们称之为矢量assoc

如何使用这个向量?

在我看来,最优雅的方法是制作uint32型的assoc,并使用位编码每个训练样本在树中的传播。

assoc(k)中的每个位代表树的某个深度,如果设置了该位(1),则表示样本k向右,否则表示样本k向左。

如果你决定采用这种策略,你会发现下面的matlab命令有用的bitgetbitsetbitshift和一些其他位函数。

让我们考虑下面的树

1
2
3
4
5
       root
      /    \
     a      b
           / \
          c   d

因此,对于所有进入节点A的示例,它们的assoc值是00b,因为它们离开了根(对应于零,至少是最低有效位(lsb))。

所有进入叶节点C的例子,它们的assoc值是01b-它们在根处向右(lsb=1),然后向左(第2位=0)。

最后,所有进入叶节点d的例子,它们的assoc值是11b,它们的分支太右。

现在,如何找到通过节点B的所有示例?

这很容易!

1
>> selNodeB = bitand( assoc, 1 );

所有LSB为1的节点在根部右转通过节点B。


如果你认为有一种方法可以存储2^40000位,你还没有意识到这个数字有多大,而且你在大约10000个数量级上是错误的。检查matlab的classregtree文档。

我从@amro的详细答案中复制了以下内容:

"以下是分类树模型的一些常见参数:

  • x:数据矩阵,行是实例,列是预测属性
  • Y:列向量,每个实例的类标签
  • 分类:指定哪些属性是离散类型(而不是连续类型)
  • 方法:生成分类树还是回归树(取决于类类型)
  • 名称:为属性命名
  • 修剪:启用/禁用减少的错误修剪
  • minparent/minleaf:如果要进一步拆分,允许指定节点中的最小实例数。
  • nvartosample:用于随机树(考虑每个节点随机选择的k个属性)
  • 权重:指定权重实例
  • 成本:指定成本矩阵(各种错误的惩罚)
  • SplitCriteria:用于在每次拆分时选择最佳属性的条件。我只熟悉基尼指数,它是信息获取标准的一个变化。
  • 先验概率:明确指定先验概率,而不是根据训练数据计算。

一个完整的例子来说明这个过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
%# load data
load carsmall

%# construct predicting attributes and target class
vars = {'MPG' 'Cylinders' 'Horsepower' 'Model_Year'};
x = [MPG Cylinders Horsepower Model_Year];
y = strcat(Origin,{});

%# train classification decision tree
t = classregtree(x, y, 'method','classification', 'names',vars, ...
                'categorical', [2 4], 'prune','off');
view(t)

%# test
yPredicted = eval(t, x);
cm = confusionmat(y,yPredicted);           %# confusion matrix
N = sum(cm(:));
err = ( N-sum(diag(cm)) ) / N;             %# testing error

%# prune tree to avoid overfitting
tt = prune(t, 'level',2);
view(tt)

%# predict a new unseen instance
inst = [33 4 78 NaN];
prediction = eval(tt, inst)

树http://img40.imageshack.us/img40/6994/greenshot20091225024654.png