Python机器学习系列之scikit-learn决策树回归问题简单实践
off999 2024-11-26 07:24 42 浏览 0 评论
1.决策树概述
在前面几期的章节中
我们分别简单介绍了一下决策树DecisionTree的原理,以及使用一个简单的案例介绍了下决策树是如何处理简单分类问题的。通过前面两个章节简单地学习,我们再次回顾下决策树的一些优缺点,总结如下:
1.1决策树的优点是:
- 易于理解和解释,树可以被可视化;
- 几乎不需要数据准备。其他算法通常需要数据标准化,需要创建虚拟变量并删除缺失值。但是,请注意,此模块不支持缺失值。
- 使用树的成本(即预测数据)是用于训练树的数据点数的对数。
- 能够处理数值型和分类型数据。其他技术通常专门分析只有一种类型变量的数据集。
- 能够处理多输出问题。
- 使用白盒模型。如果给定的情况在模型中是可以观察到的,那么对条件的解释就很容易用布尔逻辑来解释。相反,在黑箱模型中(例如,在人工神经网络中),结果可能很难解释。
- 可以使用统计测试验证模型。这样就有可能对模型的可靠性作出解释。
- 即使它的假设在某种程度上被生成数据的真实模型所违背,它也表现得很好。
1.2决策树的缺点包括:
- 决策树学习器可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过拟合。为了避免这个问题,必须设置剪枝、设置叶节点所需的最小样本数或设置树的最大深度等机制。
- 决策树可能是不稳定的,因为数据中的小变化可能导致生成完全不同的树。通过集成决策树来缓解这个问题。
- 学习最优决策树的问题在最优性的几个方面都是NP-complete的,甚至对于简单的概念也是如此。因此,实际的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上进行局部最优决策。这种算法不能保证返回全局最优决策树。这可以通过训练多棵树再集成一个学习器来缓解,其中特征和样本被随机抽取并替换。
- 有些概念很难学习,因为决策树不能很容易地表达它们,例如异或、奇偶校验或多路复用器问题。
- 如果某些类占主导地位,则决策树学习者会创建有偏见的树。因此,建议在拟合决策树之前平衡数据集。
1.3决策树相关概念
我们先来回归一下决策树的几个相关概念:
节点 | 说明 |
根节点 | 没有进边,有出边 |
中间节点 | 既有进边也有出边,但进边有且只有一条,出边也可以有很多条 |
叶节点 | 只有进边,没有出边,进边有且只有一条,每个页节点都是一个类别标签 |
父节点和字节点 | 在两个相连的节点中,更靠近根节点的是父节点,另一个则是子节点,两者是相对的 |
2.CART算法
2.1什么是CART算法?
CART是英文Classification And Regression Tree的简写,又称为分类回归树。从它的名字我们就可以看出,它是一个很强大的算法,既可以用于分类还可以用于回归,所以非常值得我们来学习。
CART算法使用的就是二元切分法,这种方法可以通过调整树的构建过程,使其能够处理连续型变量。
具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。
CART算法分为两步:
- 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量的大;自上而下从根开始构建节点,在每个节点处要选择一个最好的属性来分裂,使得字节点中的训练集尽量的纯;
- 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准;
不同的算法使用不同的指标来定义“最好”:
算法 | 指标 |
ID3 信息增益 |
|
C4.5 信息增益比 |
|
CART 基尼系数 |
|
三种方法本质上都相同,在类别分布均衡时达到最大值,而当所有记录都属于同一个类别时达到最小值。换而言之,在纯度较高时三个指数均较低,而当纯度较低时,三个指数都比较大,且可以计算得出,熵在0-1区间内分布,而Gini指数和分类误差均在0-0.5区间内分布。
CART树的构建过程:首先找到最佳的列来切分数据集,每次都执行二元切分法,如果特征值大于给定值就走左子树,否则就走右子树,当节点不能再分时就将该节点保存为叶节点。
3.回归树的sklearn实现
3.1 DecisionTreeRegressor
class sklearn.tree.DecisionTreeRegressor (criterion='mse', splitter='best', max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None,
random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, presort=False)几乎所有参数,属性及接口都和分类树一模一样。需要注意的是,在回归树种,没有标签分布是否均衡的问题,因此没有class_weight这样的参数。
3.2 重要参数,属性和接口
回归树衡量分枝质量的指标,支持的标准有三种:
- 输入"mse"使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,这种方法通过使用叶子节点的均值来最小化L2损失;
- 输入“friedman_mse”使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差;
- 输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失;
属性中最重要的依然是feature_importances_,接口依然是apply, fit, predict, score最核心。
在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,当我们在使用交叉验证,或者其他方式获取回归树的结果时,我们往往选择均方误差作为我们的评估(在分类树中这个指标是score代表的预测准确率)。在回归中,我们追求的是,MSE越小越好。
然而,回归树的接口score返回的是R平方,并不是MSE。
虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error)。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error去掉负号的数字。
3.3 简单看看回归树是怎样工作的
接下来我们到二维平面上来观察决策树是怎样拟合一条曲线的。我们用回归树来拟合正弦曲线,并添加一些噪声来观察回归树的表现。
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt导入处理二维矩阵的numpy,回归树DecisionTreeRegressor以及画图的matplotlib等必要的模块;
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))使用随机数函数生成带有噪声的数据集;
# 建立回归树模型并且训练
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# 使用回归树模型进行预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)建立两个回归树模型,回归树的深度分别是2和5,分别用上面生成的数据集训练这两个回归树模型,然后用一组测试数据集分别对这两个回归树模型进行预测。
plt.figure(figsize=(16, 9))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.savefig('DecisionTreeRegressor.jpg')
plt.show()使用matplotlib模块将上面生成的数据集和回归树预测后的数据进行绘制到一张图像中,运行效果如下:
通过上面运行效果图我们可以清楚地发现,回归树学习了近似正弦曲线的局部线性回归。我们可以看到,如果树的最大深度(由max_depth参数控制)设置得太高,则决策树学习得太精细,它从训练数据中学了很多细节,包括噪声得呈现,从而使模型偏离真实的正弦曲线,形成过拟合。
3.4简单探索下回归树DecisionTreeRegressor模型DecisionTreeRegressor
我们使用sklearn自带的交叉验证函数来看看我们上面建立的两棵回归树的均方误差是怎样的?
from sklearn.model_selection import cross_val_score
val_score1 = cross_val_score(regr_1, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score1
val_score2 = cross_val_score(regr_2, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score2我们分别对回归树模型regr_1和regr_2进行10次交叉验证,验证指标为均方误差,该函数返回一个包含10个浮点数的数组,数组里面的每一个浮点数表示一次验证的均方误差值。我们再对10次计算的均方误差进行一次求均值运输,就能够得到一个浮点数。运行效果如下:
交叉验证是用来观察模型的稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度。训练集和测试集的划分会干扰模型的结果,因此用交叉验证n次的结果求出的平均值,是对模型效果的一个更好的度量。
下面我们通过画一组学习曲线,来看看上面构建的回归树DecisionTreeRegressor和该回归树在上面我们生成的数据集训练后,max_depth与均方误差的关系
scores=[]
for i in range(10):
clf = DecisionTreeRegressor(max_depth=i+1,
random_state=10)
val_score = -cross_val_score(clf, X, y, cv=10, scoring = "neg_mean_squared_error").mean()
scores.append(val_score)我们经过10次循环,分别建立10棵max_depth不一样的回归树模型,每一颗回归树进行10次交叉验证,并且将交叉验证的结果进行求均值。
plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.ylabel("neg_mean_squared_error")
plt.grid()
plt.legend()
plt.savefig('Regressor_max_depth.jpg')
plt.show()使用matplotlib模块绘制曲线,上面代码块运行效果如下图
通过上图我们可以清晰的知道,当回归树其他参数保持不变的情况下,max_depth为2的时候,均方误差MSE最小,该回归树在数据集上的拟合性也越好;因为在回归树中,我们追求的是均方误差MSE越小越好。
不积跬步,无以至千里;
不积小流,无以成江海;
参考资料:
https://scikit-learn.org/stable/modules/tree.html
相关推荐
- directx官方下载win7(directx download)
-
点开始-----运行,输入dxdiag,回车后打开“DirectX诊断工具”窗口,进入“显示”选项卡,看一下是否启用了加速,没有的话,单击下面的“DirectX功能”项中的“启用”按钮,这样便打开了D...
- u盘视频无法播放怎么办(u盘上视频没办法播放)
-
解决办法:1.检查U盘存储格式是否为FAT32,如果不是,请将其格式化为FAT32; 2.检查U盘中视频文件是否损坏,如果有损坏文件,请尝试重新复制一份; 3.检查U盘中存储...
-
- 笔记本电脑无法正常启动怎么修复
-
1.可以解决。2.Windows未能启动可能是由于系统文件损坏、硬件故障或病毒感染等原因引起的。解决方法可以尝试使用Windows安全模式启动、修复启动、还原系统、重装系统等方法。3.如果以上方法都无法解决问题,可以考虑联系专业的电脑...
-
2025-11-16 04:03 off999
- 联想设置u盘为第一启动项(联想怎么设置u盘启动为第一启动项)
-
联想电脑设置u盘为第一启动项方法如下一、将电脑开机,开机瞬间按F2键进入bios设置界面二、在上面5个选项里找到boot选项,这里按键盘上左右键来移动三、这里利用键盘上下键选到USB选项,然后按F5/...
-
- 家用路由器哪个牌子最好信号最稳定
-
TP-LINK最好,信号最稳定。路由器是连接两个或多个网络的硬件设备,在网络间起网关的作用,是读取每一个数据包中的地址然后决定如何传送的专用智能性的网络设备。它能够理解不同的协议,例如某个局域网使用的以太网协议,因特网使用的TCP/IP协议...
-
2025-11-16 03:03 off999
- 安卓纯净版系统(安卓的纯净模式)
-
安卓系统有纯净模式的,安卓系统必须有纯净模式的,刷入纯净版系统可以去除一些预装的应用和系统自带软件,提高手机的运行速度和使用体验。但需要注意的是刷机有一定风险,请确保你已经备份好手机数据并了解安装风险...
- deepin系统怎么安装软件(deepin操作系统怎么安装软件)
-
deepin是一个基于Linux的操作系统,它默认不支持APK应用。要在deepin上安装APK应用,需要先安装一个Android模拟器,例如Anbox,然后从GooglePlayStore或其他...
-
- 下载app安装包(下载app安装包损坏)
-
1,没有刷机过的,可以在手机里面,找到系统自带的文件管理-(如图),2,点开后,可以直接看到文件分类,找到,安装包,点开,(如下图)3,即可看到手机里面的未安装APP;操作方法01如果是直接在浏览器上下载的软件,那就直接点开浏览器,然后点击...
-
2025-11-16 01:51 off999
- window7旗舰版密码忘记(win7密码忘记了怎么办旗舰版)
-
1、重启电脑按f8选择“带命令提示符的安全模式”,跳出“CommandPrompt”窗口。2、在窗口中输入“netuserasd/add”回车,再升级输入“netlocalgroupadmi...
- windows7界面(windows7界面由哪几个部分组成)
-
您好!Windows7一般有两种界面。一种为Aero界面,一种为经典界面。Aero界面还包含三个小分类:性能最佳Aero,BasicAero,对比度Aero。性能最佳Aero是Windows7最...
- wps截图快捷键(WPS截图快捷键是哪个)
-
在WPS中进行截屏,可以通过快捷键来实现。具体操作在按下“Alt+PrtSc”之后,就会将当前屏幕截图保存到剪贴板中。若需要将截图保存为图片文件,则在粘贴时选择“文件夹”而不是“粘贴”,再选定存储...
- 电脑主机自动关机是什么原因
-
原因一、软件 1.病毒破坏,自从有了计算机以后不久,计算机病毒也应运而生。当网络成为当今社会的信息大动脉后,病毒的传播更加方便,所以也时不时的干扰和破坏我们的正常工作。比较典型的就是前一段时间对...
- 显示桌面快捷键(怎么设置桌面快捷图标)
-
电脑上显示桌面的快捷键如下:1,常用。同时按Win徽标键+D键(win键位于Ctrl与Alt之间像个飘起来的田字):按一次显示桌面,再同时按一次返回到窗口。2,同时按Win徽标键+M:原本含义是“...
- 如何使用u盘拷贝文件(如何使用u盘拷贝文件到电脑)
-
1、插入u盘,在桌面上或“我的电脑”中能查看u盘信息。2、在电脑中找到需要拷贝的文件,右键点击复制。3、进入u盘界面,在空白处点击右键,选择“粘贴”即可拷贝到u盘。或者,同时打开需要复制的文件窗口和u...
欢迎 你 发表评论:
- 一周热门
-
-
抖音上好看的小姐姐,Python给你都下载了
-
全网最简单易懂!495页Python漫画教程,高清PDF版免费下载
-
Python 3.14 的 UUIDv6/v7/v8 上新,别再用 uuid4 () 啦!
-
python入门到脱坑 输入与输出—str()函数
-
飞牛NAS部署TVGate Docker项目,实现内网一键转发、代理、jx
-
宝塔面板如何添加免费waf防火墙?(宝塔面板开启https)
-
Python三目运算基础与进阶_python三目运算符判断三个变量
-
(新版)Python 分布式爬虫与 JS 逆向进阶实战吾爱分享
-
慕ke 前端工程师2024「完整」
-
失业程序员复习python笔记——条件与循环
-
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)
