65.9K
CodeProject 正在变化。 阅读更多。
Home

实现机器学习的分步指南 IX - 树回归

starIconstarIconstarIconstarIconstarIcon

5.00/5 (1投票)

2019年5月29日

CPOL

2分钟阅读

viewsIcon

4777

易于实现的机器学习

引言

在现实世界中,有些关系不是线性的。因此,不适合将线性回归应用于这些问题的分析。为了解决这个问题,我们可以采用树回归。树回归的主要思想是将问题分解为更小的子问题。如果子问题是线性的,我们可以将所有子问题的模型组合起来,得到整个问题的回归模型。

回归模型

树回归类似于决策树,包括特征选择、生成回归树和回归。

特征选择

在决策树中,我们根据信息增益选择特征。然而,对于回归树,预测值是连续的,这意味着每个样本的回归标签几乎是唯一的。因此,经验熵缺乏表征能力。所以,我们利用平方误差作为特征选择的标准,即:

\sum_{x_{i}\in R_{m}}\left(y_{i}-f\left(x_{i}\right)\right)^{2}

其中 Rm 是回归树划分的空间,f(x) 由下式给出:

f\left(x\right)=\sum_{m=1}^{M}c_{m}I\left(x\in R_{m}\right)

因此,无论样本的特征如何,同一空间内的输出都是相同的。Rm 的输出是空间内所有样本回归标签的平均值,即:

c_{m}=\arg \left(y_{i}|x_{i}\in R_{m}\right)

回归树的特征选择类似于决策树,旨在最小化损失函数,即:

\min\limits_{j,s}\left[\min\limits_{c_{1}}\sum_{x_{i}\in R_{1}\left(j,s\right)}\left(y_{i}-c_{1}\right)+\min\limits_{c_{2}}\sum_{x_{i}\in R_{2}\left(j,s\right)}\left(y_{i}-c_{2}\right)\right]

生成回归树

我们首先定义一个数据结构来保存树节点

class RegressionNode():    
    def __init__(self, index=-1, value=None, result=None, right_tree=None, left_tree=None):
        self.index = index
        self.value = value
        self.result = result
        self.right_tree = right_tree
        self.left_tree = left_tree

与决策树类似,假设我们选择了最佳特征及其对应值 (j, s),然后我们通过以下方式分割数据:

R_{1}\left(j,s\right)=\left\{ x|x^{(j)}\leq s\right\}

R_{2}\left(j,s\right)=\left\{ x|x^{(j)}> s\right\}

并且每个二元分割的输出是:

c_{m}=\frac{1}{N_{m}}\sum_{x_{i}\in R_{m}\left(j,s\right)}y_{i},x\in R_{m},m=1,2

回归树的生成与决策树几乎相同,这里不再赘述。您可以阅读 逐步指南:实现机器学习 II - 决策树 以获取更多详细信息。如果您仍然有疑问,请与我联系。我将很乐意帮助您解决关于回归树的任何问题。

    def createRegressionTree(self, data):
        # if there is no feature
        if len(data) == 0:
            self.tree_node = treeNode(result=self.getMean(data[:, -1]))
            return self.tree_node

        sample_num, feature_dim = np.shape(data)

        best_criteria = None
        best_error = np.inf
        best_set = None
        initial_error = self.getVariance(data)

        # get the best split feature and value
        for index in range(feature_dim - 1):
            uniques = np.unique(data[:, index])
            for value in uniques:
                left_set, right_set = self.divideData(data, index, value)
                if len(left_set) < self.N or len(right_set) < self.N:
                    continue
                new_error = self.getVariance(left_set) + self.getVariance(right_set)
                if new_error < best_error:
                    best_criteria = (index, value)
                    best_error = new_error
                    best_set = (left_set, right_set)

        if best_set is None:
            self.tree_node = treeNode(result=self.getMean(data[:, -1]))
            return self.tree_node
        # if the descent of error is small enough, return the mean of the data
        elif abs(initial_error - best_error) < self.error_threshold:
            self.tree_node = treeNode(result=self.getMean(data[:, -1]))
            return self.tree_node
        # if the split data is small enough, return the mean of the data
        elif len(best_set[0]) < self.N or len(best_set[1]) < self.N:
            self.tree_node = treeNode(result=self.getMean(data[:, -1]))
            return self.tree_node
        else:
            ltree = self.createRegressionTree(best_set[0])
            rtree = self.createRegressionTree(best_set[1])
            self.tree_node = treeNode(index=best_criteria[0], \
                             value=best_criteria[1], left_tree=ltree, right_tree=rtree)
            return self.tree_node

回归

回归的原理类似于二叉搜索树,即:**将节点中存储的特征值与测试对象的相应特征值进行比较。然后,递归地转向左子树或右子树**,如下所示:

    def classify(self, sample, tree):
        if tree.result is not None:
            return tree.result
        else:
            value = sample[tree.index]
            if value >= tree.value:
                branch = tree.right_tree
            else:
                branch = tree.left_tree
            return self.classify(sample, branch)

结论与分析

分类树和回归树可以组合成分类与回归树(CART)。实际上,在生成树之后,存在剪枝过程。我们跳过它们,因为它们有点复杂,而且并非总是有效的。最后,让我们将我们的回归树与 Sklearn 中的树进行比较,检测性能如下所示:

Sklearn 树回归性能

我们的树回归性能

我们的树回归比 Sklearn 的花费时间稍长。

本文相关的代码和数据集可以在 MachineLearning 中找到。

历史

  • 2019年5月29日:初始版本
© . All rights reserved.