Python机器学习系列之scikit-learn决策树回归问题简单实践
off999 2024-11-26 07:24 28 浏览 0 评论
1.决策树概述
在前面几期的章节中
我们分别简单介绍了一下决策树DecisionTree的原理,以及使用一个简单的案例介绍了下决策树是如何处理简单分类问题的。通过前面两个章节简单地学习,我们再次回顾下决策树的一些优缺点,总结如下:
1.1决策树的优点是:
- 易于理解和解释,树可以被可视化;
- 几乎不需要数据准备。其他算法通常需要数据标准化,需要创建虚拟变量并删除缺失值。但是,请注意,此模块不支持缺失值。
- 使用树的成本(即预测数据)是用于训练树的数据点数的对数。
- 能够处理数值型和分类型数据。其他技术通常专门分析只有一种类型变量的数据集。
- 能够处理多输出问题。
- 使用白盒模型。如果给定的情况在模型中是可以观察到的,那么对条件的解释就很容易用布尔逻辑来解释。相反,在黑箱模型中(例如,在人工神经网络中),结果可能很难解释。
- 可以使用统计测试验证模型。这样就有可能对模型的可靠性作出解释。
- 即使它的假设在某种程度上被生成数据的真实模型所违背,它也表现得很好。
1.2决策树的缺点包括:
- 决策树学习器可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过拟合。为了避免这个问题,必须设置剪枝、设置叶节点所需的最小样本数或设置树的最大深度等机制。
- 决策树可能是不稳定的,因为数据中的小变化可能导致生成完全不同的树。通过集成决策树来缓解这个问题。
- 学习最优决策树的问题在最优性的几个方面都是NP-complete的,甚至对于简单的概念也是如此。因此,实际的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上进行局部最优决策。这种算法不能保证返回全局最优决策树。这可以通过训练多棵树再集成一个学习器来缓解,其中特征和样本被随机抽取并替换。
- 有些概念很难学习,因为决策树不能很容易地表达它们,例如异或、奇偶校验或多路复用器问题。
- 如果某些类占主导地位,则决策树学习者会创建有偏见的树。因此,建议在拟合决策树之前平衡数据集。
1.3决策树相关概念
我们先来回归一下决策树的几个相关概念:
节点 | 说明 |
根节点 | 没有进边,有出边 |
中间节点 | 既有进边也有出边,但进边有且只有一条,出边也可以有很多条 |
叶节点 | 只有进边,没有出边,进边有且只有一条,每个页节点都是一个类别标签 |
父节点和字节点 | 在两个相连的节点中,更靠近根节点的是父节点,另一个则是子节点,两者是相对的 |
2.CART算法
2.1什么是CART算法?
CART是英文Classification And Regression Tree的简写,又称为分类回归树。从它的名字我们就可以看出,它是一个很强大的算法,既可以用于分类还可以用于回归,所以非常值得我们来学习。
CART算法使用的就是二元切分法,这种方法可以通过调整树的构建过程,使其能够处理连续型变量。
具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。
CART算法分为两步:
- 决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量的大;自上而下从根开始构建节点,在每个节点处要选择一个最好的属性来分裂,使得字节点中的训练集尽量的纯;
- 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准;
不同的算法使用不同的指标来定义“最好”:
算法 | 指标 |
ID3 信息增益 |
|
C4.5 信息增益比 |
|
CART 基尼系数 |
|
三种方法本质上都相同,在类别分布均衡时达到最大值,而当所有记录都属于同一个类别时达到最小值。换而言之,在纯度较高时三个指数均较低,而当纯度较低时,三个指数都比较大,且可以计算得出,熵在0-1区间内分布,而Gini指数和分类误差均在0-0.5区间内分布。
CART树的构建过程:首先找到最佳的列来切分数据集,每次都执行二元切分法,如果特征值大于给定值就走左子树,否则就走右子树,当节点不能再分时就将该节点保存为叶节点。
3.回归树的sklearn实现
3.1 DecisionTreeRegressor
class sklearn.tree.DecisionTreeRegressor (criterion='mse', splitter='best', max_depth=None,
min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None,
random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, presort=False)
几乎所有参数,属性及接口都和分类树一模一样。需要注意的是,在回归树种,没有标签分布是否均衡的问题,因此没有class_weight这样的参数。
3.2 重要参数,属性和接口
回归树衡量分枝质量的指标,支持的标准有三种:
- 输入"mse"使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,这种方法通过使用叶子节点的均值来最小化L2损失;
- 输入“friedman_mse”使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差;
- 输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失;
属性中最重要的依然是feature_importances_,接口依然是apply, fit, predict, score最核心。
在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,当我们在使用交叉验证,或者其他方式获取回归树的结果时,我们往往选择均方误差作为我们的评估(在分类树中这个指标是score代表的预测准确率)。在回归中,我们追求的是,MSE越小越好。
然而,回归树的接口score返回的是R平方,并不是MSE。
虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error)。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error去掉负号的数字。
3.3 简单看看回归树是怎样工作的
接下来我们到二维平面上来观察决策树是怎样拟合一条曲线的。我们用回归树来拟合正弦曲线,并添加一些噪声来观察回归树的表现。
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
导入处理二维矩阵的numpy,回归树DecisionTreeRegressor以及画图的matplotlib等必要的模块;
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
使用随机数函数生成带有噪声的数据集;
# 建立回归树模型并且训练
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# 使用回归树模型进行预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
建立两个回归树模型,回归树的深度分别是2和5,分别用上面生成的数据集训练这两个回归树模型,然后用一组测试数据集分别对这两个回归树模型进行预测。
plt.figure(figsize=(16, 9))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.savefig('DecisionTreeRegressor.jpg')
plt.show()
使用matplotlib模块将上面生成的数据集和回归树预测后的数据进行绘制到一张图像中,运行效果如下:
通过上面运行效果图我们可以清楚地发现,回归树学习了近似正弦曲线的局部线性回归。我们可以看到,如果树的最大深度(由max_depth参数控制)设置得太高,则决策树学习得太精细,它从训练数据中学了很多细节,包括噪声得呈现,从而使模型偏离真实的正弦曲线,形成过拟合。
3.4简单探索下回归树DecisionTreeRegressor模型DecisionTreeRegressor
我们使用sklearn自带的交叉验证函数来看看我们上面建立的两棵回归树的均方误差是怎样的?
from sklearn.model_selection import cross_val_score
val_score1 = cross_val_score(regr_1, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score1
val_score2 = cross_val_score(regr_2, X, y, cv=10, scoring = "neg_mean_squared_error")
val_score2
我们分别对回归树模型regr_1和regr_2进行10次交叉验证,验证指标为均方误差,该函数返回一个包含10个浮点数的数组,数组里面的每一个浮点数表示一次验证的均方误差值。我们再对10次计算的均方误差进行一次求均值运输,就能够得到一个浮点数。运行效果如下:
交叉验证是用来观察模型的稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度。训练集和测试集的划分会干扰模型的结果,因此用交叉验证n次的结果求出的平均值,是对模型效果的一个更好的度量。
下面我们通过画一组学习曲线,来看看上面构建的回归树DecisionTreeRegressor和该回归树在上面我们生成的数据集训练后,max_depth与均方误差的关系
scores=[]
for i in range(10):
clf = DecisionTreeRegressor(max_depth=i+1,
random_state=10)
val_score = -cross_val_score(clf, X, y, cv=10, scoring = "neg_mean_squared_error").mean()
scores.append(val_score)
我们经过10次循环,分别建立10棵max_depth不一样的回归树模型,每一颗回归树进行10次交叉验证,并且将交叉验证的结果进行求均值。
plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.ylabel("neg_mean_squared_error")
plt.grid()
plt.legend()
plt.savefig('Regressor_max_depth.jpg')
plt.show()
使用matplotlib模块绘制曲线,上面代码块运行效果如下图
通过上图我们可以清晰的知道,当回归树其他参数保持不变的情况下,max_depth为2的时候,均方误差MSE最小,该回归树在数据集上的拟合性也越好;因为在回归树中,我们追求的是均方误差MSE越小越好。
不积跬步,无以至千里;
不积小流,无以成江海;
参考资料:
https://scikit-learn.org/stable/modules/tree.html
相关推荐
- PYTHON-简易计算器的元素介绍
-
[烟花]了解模板代码的组成importPySimpleGUIassg#1)导入库layout=[[],[],[]]#2)定义布局,确定行数window=sg.Window(...
- 如何使用Python编写一个简单的计算器程序
-
Python是一种简单易学的编程语言,非常适合初学者入门。本文将教您如何使用Python编写一个简单易用的计算器程序,帮助您快速进行基本的数学运算。无需任何高深的数学知识,只需跟随本文的步骤,即可轻松...
- 用Python打造一个简洁美观的桌面计算器
-
最近在学习PythonGUI编程,顺手用Tkinter实现了一个简易桌面计算器,功能虽然不复杂,但非常适合新手练手。如果你正在学习Python,不妨一起来看看这个项目吧!项目背景Tkint...
- 用Python制作一个带图形界面的计算器
-
大家好,今天我要带大家使用Python制作一个具有图形界面的计算器应用程序。这个项目不仅可以帮助你巩固Python编程基础,还可以让你初步体验图形化编程的乐趣。我们将使用Python的tkinter库...
- 用python怎么做最简单的桌面计算器
-
有网友问,用python怎么做一个最简单的桌面计算器。如果只强调简单,在本机运行,不考虑安全性和容错等的话,你能想到的最简单的方案是什么呢?我觉得用tkinter加eval就够简单的。现在开整。首先创...
- 说好的《Think Python 2e》更新呢!
-
编程派微信号:codingpy本周三脱更了,不过发现好多朋友在那天去访问《ThinkPython2e》的在线版,感觉有点对不住呢(实在是没抽出时间来更新)。不过还好本周六的更新可以实现,要不就放一...
- 构建AI系统(三):使用Python设置您的第一个MCP服务器
-
是时候动手实践了!在这一部分中,我们将设置开发环境并创建我们的第一个MCP服务器。如果您从未编写过代码,也不用担心-我们将一步一步来。我们要构建什么还记得第1部分中Maria的咖啡馆吗?我们正在创...
- 函数还是类?90%程序员都踩过的Python认知误区
-
那个深夜,你在调试代码,一行行检查变量类型。突然,一个TypeError错误蹦出来,你盯着那句"strobjectisnotcallable",咖啡杯在桌上留下了一圈深色...
- 《Think Python 2e》中译版更新啦!
-
【回复“python”,送你十本电子书】又到了周三,一周快过去一半了。小编按计划更新《ThinkPython2e》最新版中译。今天更新的是第五章:条件和递归。具体内容请点击阅读原文查看。其他章节的...
- Python mysql批量更新数据(兼容动态数据库字段、表名)
-
一、应用场景上篇文章我们学会了在pymysql事务中批量插入数据的复用代码,既然有了批量插入,那批量更新和批量删除的操作也少不了。二、解决思路为了解决批量删除和批量更新的问题,提出如下思路:所有更新语...
- Python Pandas 库:解锁 combine、update 和compare函数的强大功能
-
在Python的数据处理领域,Pandas库提供了丰富且实用的函数,帮助我们高效地处理和分析数据。今天,咱们就来深入探索Pandas库中四个功能独特的函数:combine、combine_fi...
- 记录Python3.7.4更新到Python.3.7.8
-
Python官网Python安装包下载下载文件名称运行后选择升级选项等待安装安装完毕打开IDLE使用Python...
- Python千叶网原图爬虫:界面化升级实践
-
该工具以Python爬虫技术为核心,实现千叶网原图的精准抓取,突破缩略图限制,直达高清资源。新增图形化界面(GUI)后,操作门槛大幅降低:-界面集成URL输入、存储路径选择、线程设置等核心功能,...
- __future__模块:Python语言版本演进的桥梁
-
摘要Python作为一门持续演进的编程语言,在版本迭代过程中不可避免地引入了破坏性变更。__future__模块作为Python兼容性管理的核心机制,为开发者提供了在旧版本中体验新特性的能力。本文深入...
- Python 集合隐藏技能:add 与 update 的致命区别,90% 开发者都踩过坑
-
add函数的使用场景及错误注意添加单一元素:正确示例:pythons={1,2}s.add(3)print(s)#{1,2,3}错误场景:试图添加可变对象(如列表)会报错(Pytho...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python读取文件夹下所有文件 (59)
- java调用python脚本 (56)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)