Python机器学习系列之scikit-learn决策树回归问题简单实践
off999 2024-11-26 07:24 33 浏览 0 评论
1.决策树概述
在前面几期的章节中
我们分别简单介绍了一下决策树DecisionTree的原理,以及使用一个简单的案例介绍了下决策树是如何处理简单分类问题的。通过前面两个章节简单地学习,我们再次回顾下决策树的一些优缺点,总结如下:
1.1决策树的优点是:
- 易于理解和解释,树可以被可视化;
- 几乎不需要数据准备。其他算法通常需要数据标准化,需要创建虚拟变量并删除缺失值。但是,请注意,此模块不支持缺失值。
- 使用树的成本(即预测数据)是用于训练树的数据点数的对数。
- 能够处理数值型和分类型数据。其他技术通常专门分析只有一种类型变量的数据集。
- 能够处理多输出问题。
- 使用白盒模型。如果给定的情况在模型中是可以观察到的,那么对条件的解释就很容易用布尔逻辑来解释。相反,在黑箱模型中(例如,在人工神经网络中),结果可能很难解释。
- 可以使用统计测试验证模型。这样就有可能对模型的可靠性作出解释。
- 即使它的假设在某种程度上被生成数据的真实模型所违背,它也表现得很好。
1.2决策树的缺点包括:
- 决策树学习器可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过拟合。为了避免这个问题,必须设置剪枝、设置叶节点所需的最小样本数或设置树的最大深度等机制。
- 决策树可能是不稳定的,因为数据中的小变化可能导致生成完全不同的树。通过集成决策树来缓解这个问题。
- 学习最优决策树的问题在最优性的几个方面都是NP-complete的,甚至对于简单的概念也是如此。因此,实际的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上进行局部最优决策。这种算法不能保证返回全局最优决策树。这可以通过训练多棵树再集成一个学习器来缓解,其中特征和样本被随机抽取并替换。
- 有些概念很难学习,因为决策树不能很容易地表达它们,例如异或、奇偶校验或多路复用器问题。
- 如果某些类占主导地位,则决策树学习者会创建有偏见的树。因此,建议在拟合决策树之前平衡数据集。
1.3决策树相关概念
我们先来回归一下决策树的几个相关概念:
节点 | 说明 |
根节点 | 没有进边,有出边 |
中间节点 | 既有进边也有出边,但进边有且只有一条,出边也可以有很多条 |
叶节点 | 只有进边,没有出边,进边有且只有一条,每个页节点都是一个类别标签 |
父节点和字节点 | 在两个相连的节点中,更靠近根节点的是父节点,另一个则是子节点,两者是相对的 |
2.CART算法
2.1什么是CART算法?
CART是英文Classification And Regression Tree的简写,又称为分类回归树。从它的名字我们就可以看出,它是一个很强大的算法,既可以用于分类还可以用于回归,所以非常值得我们来学习。
CART算法使用的就是二元切分法,这种方法可以通过调整树的构建过程,使其能够处理连续型变量。
具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。
CART算法分为两步:
- 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量的大;自上而下从根开始构建节点,在每个节点处要选择一个最好的属性来分裂,使得字节点中的训练集尽量的纯;
- 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准;
不同的算法使用不同的指标来定义“最好”:
算法 | 指标 |
ID3 信息增益 |
|
C4.5 信息增益比 |
|
CART 基尼系数 |
|
三种方法本质上都相同,在类别分布均衡时达到最大值,而当所有记录都属于同一个类别时达到最小值。换而言之,在纯度较高时三个指数均较低,而当纯度较低时,三个指数都比较大,且可以计算得出,熵在0-1区间内分布,而Gini指数和分类误差均在0-0.5区间内分布。
CART树的构建过程:首先找到最佳的列来切分数据集,每次都执行二元切分法,如果特征值大于给定值就走左子树,否则就走右子树,当节点不能再分时就将该节点保存为叶节点。
3.回归树的sklearn实现
3.1 DecisionTreeRegressor
class sklearn.tree.DecisionTreeRegressor (criterion='mse', splitter='best', max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None,
random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, presort=False)
几乎所有参数,属性及接口都和分类树一模一样。需要注意的是,在回归树种,没有标签分布是否均衡的问题,因此没有class_weight这样的参数。
3.2 重要参数,属性和接口
回归树衡量分枝质量的指标,支持的标准有三种:
- 输入"mse"使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,这种方法通过使用叶子节点的均值来最小化L2损失;
- 输入“friedman_mse”使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差;
- 输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失;
属性中最重要的依然是feature_importances_,接口依然是apply, fit, predict, score最核心。
在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,当我们在使用交叉验证,或者其他方式获取回归树的结果时,我们往往选择均方误差作为我们的评估(在分类树中这个指标是score代表的预测准确率)。在回归中,我们追求的是,MSE越小越好。
然而,回归树的接口score返回的是R平方,并不是MSE。
虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error)。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error去掉负号的数字。
3.3 简单看看回归树是怎样工作的
接下来我们到二维平面上来观察决策树是怎样拟合一条曲线的。我们用回归树来拟合正弦曲线,并添加一些噪声来观察回归树的表现。
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
导入处理二维矩阵的numpy,回归树DecisionTreeRegressor以及画图的matplotlib等必要的模块;
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
使用随机数函数生成带有噪声的数据集;
# 建立回归树模型并且训练
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# 使用回归树模型进行预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
建立两个回归树模型,回归树的深度分别是2和5,分别用上面生成的数据集训练这两个回归树模型,然后用一组测试数据集分别对这两个回归树模型进行预测。
plt.figure(figsize=(16, 9))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.savefig('DecisionTreeRegressor.jpg')
plt.show()
使用matplotlib模块将上面生成的数据集和回归树预测后的数据进行绘制到一张图像中,运行效果如下:
通过上面运行效果图我们可以清楚地发现,回归树学习了近似正弦曲线的局部线性回归。我们可以看到,如果树的最大深度(由max_depth参数控制)设置得太高,则决策树学习得太精细,它从训练数据中学了很多细节,包括噪声得呈现,从而使模型偏离真实的正弦曲线,形成过拟合。
3.4简单探索下回归树DecisionTreeRegressor模型DecisionTreeRegressor
我们使用sklearn自带的交叉验证函数来看看我们上面建立的两棵回归树的均方误差是怎样的?
from sklearn.model_selection import cross_val_score
val_score1 = cross_val_score(regr_1, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score1
val_score2 = cross_val_score(regr_2, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score2
我们分别对回归树模型regr_1和regr_2进行10次交叉验证,验证指标为均方误差,该函数返回一个包含10个浮点数的数组,数组里面的每一个浮点数表示一次验证的均方误差值。我们再对10次计算的均方误差进行一次求均值运输,就能够得到一个浮点数。运行效果如下:
交叉验证是用来观察模型的稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度。训练集和测试集的划分会干扰模型的结果,因此用交叉验证n次的结果求出的平均值,是对模型效果的一个更好的度量。
下面我们通过画一组学习曲线,来看看上面构建的回归树DecisionTreeRegressor和该回归树在上面我们生成的数据集训练后,max_depth与均方误差的关系
scores=[]
for i in range(10):
clf = DecisionTreeRegressor(max_depth=i+1,
random_state=10)
val_score = -cross_val_score(clf, X, y, cv=10, scoring = "neg_mean_squared_error").mean()
scores.append(val_score)
我们经过10次循环,分别建立10棵max_depth不一样的回归树模型,每一颗回归树进行10次交叉验证,并且将交叉验证的结果进行求均值。
plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.ylabel("neg_mean_squared_error")
plt.grid()
plt.legend()
plt.savefig('Regressor_max_depth.jpg')
plt.show()
使用matplotlib模块绘制曲线,上面代码块运行效果如下图
通过上图我们可以清晰的知道,当回归树其他参数保持不变的情况下,max_depth为2的时候,均方误差MSE最小,该回归树在数据集上的拟合性也越好;因为在回归树中,我们追求的是均方误差MSE越小越好。
不积跬步,无以至千里;
不积小流,无以成江海;
参考资料:
https://scikit-learn.org/stable/modules/tree.html
相关推荐
- apisix动态修改路由的原理_动态路由协议rip的配置
-
ApacheAPISIX能够实现动态修改路由(DynamicRouting)的核心原理,是它将传统的静态Nginx配置彻底解耦,通过中心化配置存储(如etcd)+OpenRest...
- 使用 Docker 部署 OpenResty Manager 搭建可视化反向代理系统
-
在之前的文章中,xiaoz推荐过可视化Nginx反向代理工具NginxProxyManager,最近xiaoz还发现一款功能更加强大,界面更加漂亮的OpenRestyManager,完全可以替代...
- OpenResty 入门指南:从基础到动态路由实战
-
一、引言1.1OpenResty简介OpenResty是一款基于Nginx的高性能Web平台,通过集成Lua脚本和丰富的模块,将Nginx从静态反向代理转变为可动态编程的应用平台...
- OpenResty 的 Lua 动态能力_openresty 动态upstream
-
OpenResty的Lua动态能力是其最核心的优势,它将LuaJIT嵌入到Nginx的每一个请求处理阶段,使得开发者可以用Lua脚本动态控制请求的生命周期,而无需重新编译或rel...
- LVS和Nginx_lvs和nginx的区别
-
LVS(LinuxVirtualServer)和Nginx都是常用的负载均衡解决方案,广泛应用于大型网站和分布式系统中,以提高系统的性能、可用性和可扩展性。一、基本概念1.LVS(Linux...
- 外网连接到内网服务器需要端口映射吗,如何操作?
-
外网访问内网服务器通常需要端口映射(或内网穿透),这是跨越公网与私网边界的关键技术。操作方式取决于网络环境,以下分场景详解。一、端口映射的核心原理内网服务器位于私有IP地址段(如192.168.x.x...
- Nginx如何解决C10K问题(1万个并发连接)?
-
关注△mikechen△,十余年BAT架构经验倾囊相授!大家好,我是mikechen。Nginx是大型架构的必备中间件,下面我就全面来详解NginxC10k问题@mikechen文章来源:mikec...
- 炸场!Spring Boot 9 大内置过滤器实战手册:从坑到神
-
炸场!SpringBoot9大内置过滤器实战手册:从坑到神在Java开发圈摸爬滚打十年,见过太多团队重复造轮子——明明SpringBoot自带的过滤器就能解决的问题,偏偏要手写几十...
- WordPress和Typecho xmlrpc漏洞_wordpress主题漏洞
-
一般大家都关注WordPress,毕竟用户量巨大,而国内的Typecho作为轻量级的博客系统就关注的人并不多。Typecho有很多借鉴WordPress的,包括兼容的xmlrpc接口,而WordPre...
- Linux Shell 入门教程(六):重定向、管道与命令替换
-
在前几篇中,我们学习了函数、流程控制等Shell编程的基础内容。现在我们来探索更高级的功能:如何控制数据流向、将命令链接在一起、让命令间通信变得可能。一、输入输出重定向(>、>>...
- Nginx的location匹配规则,90%的人都没完全搞懂,一张图让你秒懂
-
刚配完nginx网站就崩了?运维和开发都头疼的location匹配规则优先级,弄错顺序直接导致500错误。核心在于nginx处理location时顺序严格:先精确匹配=,然后前缀匹配^~,接着按顺序正...
- liunx服务器查看故障命令有那些?_linux查看服务器性能命令
-
在Linux服务器上排查故障时,需要使用一系列命令来检查系统状态、日志文件、资源利用情况以及网络状况。以下是常用的故障排查命令,按照不同场景分类说明。1.系统资源相关命令1.1查看CPU使...
- 服务器被入侵的常见迹象有哪些?_服务器入侵可以被完全操纵吗
-
服务器被入侵可能会导致数据泄露、服务异常或完全失控。及时发现入侵迹象能够帮助你尽早采取措施,减少损失。以下是服务器被入侵的常见迹象以及相关的分析与处理建议。1.服务器被入侵的常见迹象1.1系统性能...
- 前端错误可观测最佳实践_前端错误提示
-
场景解析对于前端项目,生产环境的代码通常经过压缩、混淆和打包处理,当代码在运行过程中产生错误时,通常难以还原原始代码从而定位问题,对于深度混淆尤其如此,因此Mozilla自2011年开始发起并...
- 8个能让你的Kubernetes集群“瞬间崩溃”的配置错误
-
错误一:livenessProbe探针“自杀式”配置——30秒内让Pod重启20次现象:Pod状态在Running→Terminating→CrashLoopBackOff之间循环,重启间隔仅...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- apisix动态修改路由的原理_动态路由协议rip的配置
- 使用 Docker 部署 OpenResty Manager 搭建可视化反向代理系统
- OpenResty 入门指南:从基础到动态路由实战
- OpenResty 的 Lua 动态能力_openresty 动态upstream
- LVS和Nginx_lvs和nginx的区别
- 外网连接到内网服务器需要端口映射吗,如何操作?
- Nginx如何解决C10K问题(1万个并发连接)?
- 炸场!Spring Boot 9 大内置过滤器实战手册:从坑到神
- WordPress和Typecho xmlrpc漏洞_wordpress主题漏洞
- Linux Shell 入门教程(六):重定向、管道与命令替换
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)