百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术资源 > 正文

Python机器学习系列之scikit-learn决策树回归问题简单实践

off999 2024-11-26 07:24 24 浏览 0 评论

1.决策树概述

在前面几期的章节中

我们分别简单介绍了一下决策树DecisionTree的原理,以及使用一个简单的案例介绍了下决策树是如何处理简单分类问题的。通过前面两个章节简单地学习,我们再次回顾下决策树的一些优缺点,总结如下:

1.1决策树的优点是:

  • 易于理解和解释,树可以被可视化;
  • 几乎不需要数据准备。其他算法通常需要数据标准化,需要创建虚拟变量并删除缺失值。但是,请注意,此模块不支持缺失值。
  • 使用树的成本(即预测数据)是用于训练树的数据点数的对数。
  • 能够处理数值型和分类型数据。其他技术通常专门分析只有一种类型变量的数据集。
  • 能够处理多输出问题。
  • 使用白盒模型。如果给定的情况在模型中是可以观察到的,那么对条件的解释就很容易用布尔逻辑来解释。相反,在黑箱模型中(例如,在人工神经网络中),结果可能很难解释。
  • 可以使用统计测试验证模型。这样就有可能对模型的可靠性作出解释。
  • 即使它的假设在某种程度上被生成数据的真实模型所违背,它也表现得很好。

1.2决策树的缺点包括:

  • 决策树学习器可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过拟合。为了避免这个问题,必须设置剪枝、设置叶节点所需的最小样本数或设置树的最大深度等机制。
  • 决策树可能是不稳定的,因为数据中的小变化可能导致生成完全不同的树。通过集成决策树来缓解这个问题。
  • 学习最优决策树的问题在最优性的几个方面都是NP-complete的,甚至对于简单的概念也是如此。因此,实际的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上进行局部最优决策。这种算法不能保证返回全局最优决策树。这可以通过训练多棵树再集成一个学习器来缓解,其中特征和样本被随机抽取并替换。
  • 有些概念很难学习,因为决策树不能很容易地表达它们,例如异或、奇偶校验或多路复用器问题。
  • 如果某些类占主导地位,则决策树学习者会创建有偏见的树。因此,建议在拟合决策树之前平衡数据集。

1.3决策树相关概念

我们先来回归一下决策树的几个相关概念:

节点

说明

根节点

没有进边,有出边

中间节点

既有进边也有出边,但进边有且只有一条,出边也可以有很多条

叶节点

只有进边,没有出边,进边有且只有一条,每个页节点都是一个类别标签

父节点和字节点

在两个相连的节点中,更靠近根节点的是父节点,另一个则是子节点,两者是相对的

2.CART算法

2.1什么是CART算法?

CART是英文Classification And Regression Tree的简写,又称为分类回归树。从它的名字我们就可以看出,它是一个很强大的算法,既可以用于分类还可以用于回归,所以非常值得我们来学习。

CART算法使用的就是二元切分法,这种方法可以通过调整树的构建过程,使其能够处理连续型变量。

具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。

CART算法分为两步:

  • 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量的大;自上而下从根开始构建节点,在每个节点处要选择一个最好的属性来分裂,使得字节点中的训练集尽量的纯;
  • 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准;

不同的算法使用不同的指标来定义“最好”:

算法

指标

ID3

信息增益

  • 集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差;
  • 信息增益g(D,A)=H(D)-H(D|A);

C4.5

信息增益比

  • 信息增益g(D,A)与训练数据集D关于特征A的值的熵之比;
  • 信息增益比;

CART

基尼系数

  • 作为分类树时,使用gini系数来划分分支,gini(D)表示集合D的不确定性,基尼系数gini(D,A)表示经过A=a分割后集合D的不确定性;

三种方法本质上都相同,在类别分布均衡时达到最大值,而当所有记录都属于同一个类别时达到最小值。换而言之,在纯度较高时三个指数均较低,而当纯度较低时,三个指数都比较大,且可以计算得出,熵在0-1区间内分布,而Gini指数和分类误差均在0-0.5区间内分布。

CART树的构建过程:首先找到最佳的列来切分数据集,每次都执行二元切分法,如果特征值大于给定值就走左子树,否则就走右子树,当节点不能再分时就将该节点保存为叶节点。

3.回归树的sklearn实现

3.1 DecisionTreeRegressor

class sklearn.tree.DecisionTreeRegressor (criterion='mse', splitter='best', max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None,
random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, presort=False)

几乎所有参数,属性及接口都和分类树一模一样。需要注意的是,在回归树种,没有标签分布是否均衡的问题,因此没有class_weight这样的参数。

3.2 重要参数,属性和接口

回归树衡量分枝质量的指标,支持的标准有三种:

  1. 输入"mse"使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,这种方法通过使用叶子节点的均值来最小化L2损失;
  2. 输入“friedman_mse”使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差;
  3. 输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失;

属性中最重要的依然是feature_importances_,接口依然是apply, fit, predict, score最核心。

在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,当我们在使用交叉验证,或者其他方式获取回归树的结果时,我们往往选择均方误差作为我们的评估(在分类树中这个指标是score代表的预测准确率)。在回归中,我们追求的是,MSE越小越好。

然而,回归树的接口score返回的是R平方,并不是MSE。

虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error)。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error去掉负号的数字。

3.3 简单看看回归树是怎样工作的

接下来我们到二维平面上来观察决策树是怎样拟合一条曲线的。我们用回归树来拟合正弦曲线,并添加一些噪声来观察回归树的表现。

import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

导入处理二维矩阵的numpy,回归树DecisionTreeRegressor以及画图的matplotlib等必要的模块;

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

使用随机数函数生成带有噪声的数据集;

# 建立回归树模型并且训练
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)

# 使用回归树模型进行预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

建立两个回归树模型,回归树的深度分别是2和5,分别用上面生成的数据集训练这两个回归树模型,然后用一组测试数据集分别对这两个回归树模型进行预测。

plt.figure(figsize=(16, 9))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.savefig('DecisionTreeRegressor.jpg')
plt.show()

使用matplotlib模块将上面生成的数据集和回归树预测后的数据进行绘制到一张图像中,运行效果如下:

通过上面运行效果图我们可以清楚地发现,回归树学习了近似正弦曲线的局部线性回归。我们可以看到,如果树的最大深度(由max_depth参数控制)设置得太高,则决策树学习得太精细,它从训练数据中学了很多细节,包括噪声得呈现,从而使模型偏离真实的正弦曲线,形成过拟合。

3.4简单探索下回归树DecisionTreeRegressor模型DecisionTreeRegressor

我们使用sklearn自带的交叉验证函数来看看我们上面建立的两棵回归树的均方误差是怎样的?

from sklearn.model_selection import cross_val_score

val_score1 = cross_val_score(regr_1, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score1

val_score2 = cross_val_score(regr_2, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score2

我们分别对回归树模型regr_1和regr_2进行10次交叉验证,验证指标为均方误差,该函数返回一个包含10个浮点数的数组,数组里面的每一个浮点数表示一次验证的均方误差值。我们再对10次计算的均方误差进行一次求均值运输,就能够得到一个浮点数。运行效果如下:

交叉验证是用来观察模型的稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度。训练集和测试集的划分会干扰模型的结果,因此用交叉验证n次的结果求出的平均值,是对模型效果的一个更好的度量。

下面我们通过画一组学习曲线,来看看上面构建的回归树DecisionTreeRegressor和该回归树在上面我们生成的数据集训练后,max_depth与均方误差的关系

scores=[]
for i in range(10):
    clf = DecisionTreeRegressor(max_depth=i+1,
                                 random_state=10)
    val_score = -cross_val_score(clf, X, y, cv=10, scoring = "neg_mean_squared_error").mean()
    scores.append(val_score)

我们经过10次循环,分别建立10棵max_depth不一样的回归树模型,每一颗回归树进行10次交叉验证,并且将交叉验证的结果进行求均值。

plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.ylabel("neg_mean_squared_error")
plt.grid()
plt.legend()
plt.savefig('Regressor_max_depth.jpg')
plt.show()

使用matplotlib模块绘制曲线,上面代码块运行效果如下图

通过上图我们可以清晰的知道,当回归树其他参数保持不变的情况下,max_depth为2的时候,均方误差MSE最小,该回归树在数据集上的拟合性也越好;因为在回归树中,我们追求的是均方误差MSE越小越好。


不积跬步,无以至千里;

不积小流,无以成江海;

参考资料:

https://scikit-learn.org/stable/modules/tree.html

相关推荐

面试官:来,讲一下枚举类型在开发时中实际应用场景!

一.基本介绍枚举是JDK1.5新增的数据类型,使用枚举我们可以很好的描述一些特定的业务场景,比如一年中的春、夏、秋、冬,还有每周的周一到周天,还有各种颜色,以及可以用它来描述一些状态信息,比如错...

一日一技:11个基本Python技巧和窍门

1.两个数字的交换.x,y=10,20print(x,y)x,y=y,xprint(x,y)输出:102020102.Python字符串取反a="Ge...

Python Enum 技巧,让代码更简洁、更安全、更易维护

如果你是一名Python开发人员,你很可能使用过enum.Enum来创建可读性和可维护性代码。今天发现一个强大的技巧,可以让Enum的境界更进一层,这个技巧不仅能提高可读性,还能以最小的代价增...

Python元组编程指导教程(python元组的概念)

1.元组基础概念1.1什么是元组元组(Tuple)是Python中一种不可变的序列类型,用于存储多个有序的元素。元组与列表(list)类似,但元组一旦创建就不能修改(不可变),这使得元组在某些场景...

你可能不知道的实用 Python 功能(python有哪些用)

1.超越文件处理的内容管理器大多数开发人员都熟悉使用with语句进行文件操作:withopen('file.txt','r')asfile:co...

Python 2至3.13新特性总结(python 3.10新特性)

以下是Python2到Python3.13的主要新特性总结,按版本分类整理:Python2到Python3的重大变化Python3是一个不向后兼容的版本,主要改进包括:pri...

Python中for循环访问索引值的方法

技术背景在Python编程中,我们经常需要在循环中访问元素的索引值。例如,在处理列表、元组等可迭代对象时,除了要获取元素本身,还需要知道元素的位置。Python提供了多种方式来实现这一需求,下面将详细...

Python enumerate核心应用解析:索引遍历的高效实践方案

喜欢的条友记得关注、点赞、转发、收藏,你们的支持就是我最大的动力源泉。根据GitHub代码分析统计,使用enumerate替代range(len())写法可减少38%的索引错误概率。本文通过12个生产...

Python入门到脱坑经典案例—列表去重

列表去重是Python编程中常见的操作,下面我将介绍多种实现列表去重的方法,从基础到进阶,帮助初学者全面掌握这一技能。方法一:使用集合(set)去重(最简单)pythondefremove_dupl...

Python枚举类工程实践:常量管理的标准化解决方案

本文通过7个生产案例,系统解析枚举类在工程实践中的应用,覆盖状态管理、配置选项、错误代码等场景,适用于Web服务开发、自动化测试及系统集成领域。一、基础概念与语法演进1.1传统常量与枚举类对比#传...

让Python枚举更强大!教你玩转Enum扩展

为什么你需要关注Enum?在日常开发中,你是否经常遇到这样的代码?ifstatus==1:print("开始处理")elifstatus==2:pri...

Python枚举(Enum)技巧,你值得了解

枚举(Enum)提供了更清晰、结构化的方式来定义常量。通过为枚举添加行为、自动分配值和存储额外数据,可以提升代码的可读性、可维护性,并与数据库结合使用时,使用字符串代替数字能简化调试和查询。Pytho...

78行Python代码帮你复现微信撤回消息!

来源:悟空智能科技本文约700字,建议阅读5分钟。本文基于python的微信开源库itchat,教你如何收集私聊撤回的信息。[导读]Python曾经对我说:"时日不多,赶紧用Python"。于是看...

登录人人都是产品经理即可获得以下权益

文章介绍如何利用Cursor自动开发Playwright网页自动化脚本,实现从选题、写文、生图的全流程自动化,并将其打包成API供工作流调用,提高工作效率。虽然我前面文章介绍了很多AI工作流,但它们...

Python常用小知识-第二弹(python常用方法总结)

一、Python中使用JsonPath提取字典中的值JsonPath是解析Json字符串用的,如果有一个多层嵌套的复杂字典,想要根据key和下标来批量提取value,这是比较困难的,使用jsonpat...

取消回复欢迎 发表评论: