Python机器学习系列之scikit-learn决策树分类问题简单实践
off999 2024-11-26 07:24 25 浏览 0 评论
1.决策树分类问题实践
在前面的一个章节中,我们简要地概述了一下决策树的原理知识,了解了一下决策树分支原理,调参过程以及可视化决策树:
在决策树DecisionTree中,DecisionTreeClassifier是能够处理一些分类问题的。与其他分类器一样,DecisionTreeClassifier将两个数组作为输入:一个数组X,稀疏或者密集,shape (n_samples, n_features),以及一个整数值数组Y,shape (n_samples,)
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
拟合后,该模型可用于预测样本类别:
clf.predict([[2., 2.]])
array([1])
如果存在多个具有相同且最高概率的类,分类器将预测这些类中具有最低索引的类。
作为输出特点类的替代方法,可以预测每个类的概率,即该类在叶子节点中的训练样本的分数:
clf.predict_proba([[2., 2.]])
array([[0., 1.]])
DecisionTreeClassifier能够进行二元(其中标签为[-1,1])分类和多分类(其中标签为[0,.....,k-1])分类。
下面通过一个简单的官方示例来了解下,绘制根据iris数据集的一对特征进行训练的决策树的决策面。
对于每一对iris数据集特征,决策树学习由训练样本推断的简单阈值规则组合构成的决策边界。
1.1 加载数据集
首先加载scikit-learn附带的iris数据集样本:
from sklearn.datasets import load_iris
iris = load_iris()
通过数据集的属性简单探索下iris数据集的结构:
iris.data.shape
(150, 4)
通过shape属性我们可以清楚地看到,iris数据集是一个二维矩阵,数据集中有150个样本,每一个样本包含4个特征属性;
iris.target.shape
(150,)
其实,标签数据是一个一维矩阵,标签中也有150个样本;从这里能够充分说明iris数据集中,特征的样本数与标签的样本数是一样的,都是150个样本。
import numpy as np
np.unique(iris.target)
array([0, 1, 2])
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
通过以上代码,我们可以验证标签数据集中包含3个类别,分别是:setosa,versicolor和virginica。
import pandas as pd
csv = pd.concat([pd.DataFrame(iris.data, columns=iris.feature_names), pd.DataFrame(iris.target, columns=['target'])], axis=1)
csv
csv.head()
通过pandas将数据集转化成csv格式,调用head()函数,我们可以清楚地看到iris数据集的全貌,经过简单的探索,我们可以清楚地发现,iris数据集是一个包含150个样本数据,4个特征向量,以及标签包含150个样本和3个分类的数据集。
1.2建立模型
下面开始建立决策树模型,导入决策树模块和一些辅助的模块:
# 处理数组和矩阵的模块
import numpy as np
# 处理画图的模块
import matplotlib.pyplot as plt
# 决策树模块
from sklearn.tree import DecisionTreeClassifier
# 定义所需要的参数
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02
# 设置画布大小
plt.figure(figsize=(12, 6))
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
X = iris.data[:, pair]
y = iris.target
# 训练模型
clf = DecisionTreeClassifier().fit(X, y)
# 画图
plt.subplot(2, 3, pairidx + 1)
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(
np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step)
)
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
# 预测
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
plt.xlabel(iris.feature_names[pair[0]])
plt.ylabel(iris.feature_names[pair[1]])
for i, color in zip(range(n_classes), plot_colors):
idx = np.where(y == i)
plt.scatter(
X[idx, 0],
X[idx, 1],
c=color,
label=iris.target_names[i],
cmap=plt.cm.RdYlBu,
edgecolor="black",
s=15,
)
plt.suptitle("Decision surface of decision trees trained on pairs of features")
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
_ = plt.axis("tight")
1.3运行效果
上面的代码在jupyter lab里运行如下:
绘制的决策树决策边界如上图所示,我们可以简单总结如下:
- iris数据集在决策树上经过训练(fit)和预测(predict)后,绘制出3个决策面,因为iris数据集中包括3个分类;
- 3种不同的颜色对应3种不同的分类,通过3个不同的决策面把iris数据集划分成3个不同的类别;
- 仔细观察绘制的图例,我们可以发现大部分相同颜色的点都落在了相同颜色的面,红色点落在红色面中,蓝色点落在蓝色面中;这表示这部分样本被正确的分类了,也就是说大部分样本都被正确地分类了;
- 每个分类的决策面都存在一条清晰明确的分割线,把3个类别划分开;
- 当我们放大图例的图片时,我们可以看到有极个别的点落在了其他分类的决策面中,比如:有几个红色的点落在了蓝色的决策面中,这说明这部分少数的点被误分类了;
- 大部分点都落在了正确的决策面,少数部分点落在了其他决策面,这充分说明决策树DecisionTreeClassifier分类的准确率不是100%正确分类的,存在一定比例的错误分类;
- 少数点落在其他决策面的问题就是错误分类的问题;
1.4绘制决策树结构
下面通过简单的几行代码来可视化下iris数据集上,决策树的形状:
from sklearn.tree import plot_tree
plt.figure(figsize=(16, 9))
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf,
filled=True,
rounded=True,
feature_names=iris.feature_names,
class_names=['setosa', 'versicolor', 'virginica'])
plt.title("Decision tree trained on all the iris features")
plt.savefig('DecisionTreeClassifier.jpg')
plt.show()
通过scikit-learn模块的plot_tree函数,我们可以轻松地绘制一棵决策树的树结构:
- 决策树有一个根节点,许多父节点和叶子节点;
- 每一个节点,不管是根节点,父节点还是叶子节点,都包含许多信息,例如:gini系数;因为决策树默认的分枝方法“不纯度”采用的是gini系数;
- 这棵决策树的节点不同颜色代表不同的分类类别,当节点的颜色越深的时候,该节点的gini系数也是最低的;
- 决策树的所有叶子节点的gini系数都为0,也就是说不纯度为0的时候我们就可以选择出标签的一个类别了;
1.5探索决策树结构
我们详细地探索下每一个节点都包含哪些数据,都有哪些含义:
- 每棵树节点都有5行数据,每行数据代表的意义是不一样的;
- 第一行数据表示对特征点进行提问,该节点分枝的左节点是对这个提问的YES回答,该节点分枝的右节点是对这个提问的NO回答;
- 第二行数据是不纯度指标,默认是gini系数的值;通过这棵决策树我们不难发现整个决策树从根节点到叶子节点gini系数是不断变小的;也就是说不纯度是降低的。
- 不纯度基于节点来计算,树中的每个节点都会有一个不纯度,并且子节点的不纯度一定是低于父节点的,也就是说,在同一棵决策树上,叶子节点的不纯度一定是最低的。
- 第三行数据samples表示该节点包含的样本数量;
- 第四行数据value表示标签的每一个类别所占的数量,因为iris数据集中标签包含3个类别,所以value是包含3个数值的数组;
- 第五行数据class表示该节点所在的分类类别;
我们再深入的探索下,在这棵决策树中哪些特征指标对这棵树起决定性作用呢?我们可以使用决策树的关键属性feature_importances_来看一下,哪些特征对这样一棵决策树的贡献最大。
clf.feature_importances_
[*zip(iris.feature_names,clf.feature_importances_)]
我们来简单分析下,iris数据集中一个样本包含4个特征数据,其中“petal length”花瓣的长度对这棵决策树的贡献是最大的,因为它的数值最大,数值越大对决策树的贡献也就越大;“sepal width”萼片宽度对这棵决策树的贡献最小,它的值为0。
1.6决策树的准确率
下面我们通过简单的几行代码来了解下决策树分类的准确率问题
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(iris.data,iris.target,test_size=0.3)
clf = DecisionTreeClassifier()
clf = clf.fit(Xtrain, Ytrain)
#返回预测的准确度accuracy
score = clf.score(Xtest, Ytest)
score
运行效果如下:
可以通过决策树DecisionTreeClassifier的score函数获取决策树分类的准确率,通过以上代码我们可以非常容易的获取一棵默认参数的决策树在iris数据集上分类的准确率为95%以上;
在这里不得不提的是,所有接口中要求输入X_train和X_test的部分,输入的特征矩阵必须至少是一个二维矩阵。sklearn不接受任何一维矩阵作为特征矩阵被输入。如果你的数据的确只有一个特征,那必须用reshape(-1,1)来给矩阵增维;如果你的数据只有一个特征和一个样本,使用reshape(1,-1)来给你的数据增维。
1.7决策树随机性
random_state用来设置分枝中的随机模式的参数,默认None,在高维度时随机性会表现更明显,低维度的数据,随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。
当我们反复运行上面决策树准确率的代码的时候,输出的score会每运行一次数值就会变化一次,这是因为我们建立的决策树模型没有设定任何参数,都是采用决策树默认的参数,那么当我们给决策树设定一个random_state看看效果如何?
clf = DecisionTreeClassifier(random_state=10)
clf = clf.fit(Xtrain, Ytrain)
#返回预测的准确度accuracy
score = clf.score(Xtest, Ytest)
score
当我们加上random_state参数后,无论运行多少次上面的代码,score的输出值就已经能够确定下来了,也就是说score的值不会再改变了;random_state的取值可以是任意整数值,它只是代表决策树的随机性能够被确定下来。
1.8决策树剪枝参数
在不加限制的情况下,一棵决策树会生长到衡量不纯度的指标最优,或者没有更多的特征可用为止。这样的决策树往往会过拟合,这就是说,它会在训练集上表现很好,但在测试集上却表现糟糕。我们收集的样本数据不可能和整体的状况完全一致,因此当一棵决策树对训练数据有了过于优秀的解释性,它找出的规则必然包含了训练样本中的噪声,并使它对未知数据的拟合程度不足。
clf = DecisionTreeClassifier(random_state=30
,max_depth=3
,min_samples_leaf=10
,min_samples_split=10
)
clf = clf.fit(iris.data, iris.target)
plt.figure(figsize=(16, 9))
plot_tree(clf,
filled=True,
rounded=True,
feature_names=iris.feature_names,
class_names=['setosa', 'versicolor', 'virginica'])
plt.savefig('cut_DecisionTreeClassifier.jpg')
plt.show()
剪枝后的决策树运行效果:
通过运行后生成的剪枝决策树结构与上面没有剪枝的决策树结构进行对比,我们可以清楚地发现,剪枝后的决策树的树结构更加扁平化了,树的深度降低了。具体参数解释如下:
- random_state用来设置分枝中的随机模式的参数,默认None;
- max_depth限制树的最大深度,超过设定深度的树枝全部剪掉;
- min_samples_leaf和min_samples_split表示min_samples_leaf限定一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生;
1.9确认最优的剪枝参数
那具体怎么来确定每个参数填写什么值呢?这时候,我们就要使用确定超参数的曲线来进行判断了,继续使用我们已经训练好的决策树模型clf。超参数的学习曲线,是一条以超参数的取值为横坐标,模型的度量指标为纵坐标的曲线,它是用来衡量不同超参数取值下模型的表现的线。在我们建好的决策树里,我们的模型度量指标就是score。
scores=[]
for i in range(10):
clf = DecisionTreeClassifier(max_depth=i+1,
random_state=10)
clf = clf.fit(Xtrain, Ytrain)
score= clf.score(Xtest, Ytest)
scores.append(score)
plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.grid()
plt.legend()
plt.savefig('max_depth.jpg')
plt.show()
运行效果图:
通过上面的学习曲线,我们可以很容易地发现,当max_depth为3的时候,我们的这个决策树模型的准确率已经保持不变了,最大准确率保持在97%以上;也就是说这棵决策树在保持其他默认参数的情况下,树的最大深度为3的时候,这棵决策树的分类效果是最后的,分类准确率最高。
无论如何,剪枝参数的默认值会让树无尽地生长,这些树在某些数据集上可能非常巨大,对内存的消耗也非常巨大。所以如果你手中的数据集非常巨大,你已经预测到无论如何你都是要剪枝的,那提前设定这些参数来控制树的复杂性和大小会比较好。
至此,我们已经学完了决策树DecisionTreeClassifier和用决策树绘图(plot_tree)的所有基础。
不积跬步,无以至千里;
不积小流,无以成江海;
参考资料:
https://scikit-learn.org/stable/modules/tree.html
- 上一篇:python决策树-2
- 下一篇:使用Scikit-Learn了解决策树分类
相关推荐
- apisix动态修改路由的原理_动态路由协议rip的配置
-
ApacheAPISIX能够实现动态修改路由(DynamicRouting)的核心原理,是它将传统的静态Nginx配置彻底解耦,通过中心化配置存储(如etcd)+OpenRest...
- 使用 Docker 部署 OpenResty Manager 搭建可视化反向代理系统
-
在之前的文章中,xiaoz推荐过可视化Nginx反向代理工具NginxProxyManager,最近xiaoz还发现一款功能更加强大,界面更加漂亮的OpenRestyManager,完全可以替代...
- OpenResty 入门指南:从基础到动态路由实战
-
一、引言1.1OpenResty简介OpenResty是一款基于Nginx的高性能Web平台,通过集成Lua脚本和丰富的模块,将Nginx从静态反向代理转变为可动态编程的应用平台...
- OpenResty 的 Lua 动态能力_openresty 动态upstream
-
OpenResty的Lua动态能力是其最核心的优势,它将LuaJIT嵌入到Nginx的每一个请求处理阶段,使得开发者可以用Lua脚本动态控制请求的生命周期,而无需重新编译或rel...
- LVS和Nginx_lvs和nginx的区别
-
LVS(LinuxVirtualServer)和Nginx都是常用的负载均衡解决方案,广泛应用于大型网站和分布式系统中,以提高系统的性能、可用性和可扩展性。一、基本概念1.LVS(Linux...
- 外网连接到内网服务器需要端口映射吗,如何操作?
-
外网访问内网服务器通常需要端口映射(或内网穿透),这是跨越公网与私网边界的关键技术。操作方式取决于网络环境,以下分场景详解。一、端口映射的核心原理内网服务器位于私有IP地址段(如192.168.x.x...
- Nginx如何解决C10K问题(1万个并发连接)?
-
关注△mikechen△,十余年BAT架构经验倾囊相授!大家好,我是mikechen。Nginx是大型架构的必备中间件,下面我就全面来详解NginxC10k问题@mikechen文章来源:mikec...
- 炸场!Spring Boot 9 大内置过滤器实战手册:从坑到神
-
炸场!SpringBoot9大内置过滤器实战手册:从坑到神在Java开发圈摸爬滚打十年,见过太多团队重复造轮子——明明SpringBoot自带的过滤器就能解决的问题,偏偏要手写几十...
- WordPress和Typecho xmlrpc漏洞_wordpress主题漏洞
-
一般大家都关注WordPress,毕竟用户量巨大,而国内的Typecho作为轻量级的博客系统就关注的人并不多。Typecho有很多借鉴WordPress的,包括兼容的xmlrpc接口,而WordPre...
- Linux Shell 入门教程(六):重定向、管道与命令替换
-
在前几篇中,我们学习了函数、流程控制等Shell编程的基础内容。现在我们来探索更高级的功能:如何控制数据流向、将命令链接在一起、让命令间通信变得可能。一、输入输出重定向(>、>>...
- Nginx的location匹配规则,90%的人都没完全搞懂,一张图让你秒懂
-
刚配完nginx网站就崩了?运维和开发都头疼的location匹配规则优先级,弄错顺序直接导致500错误。核心在于nginx处理location时顺序严格:先精确匹配=,然后前缀匹配^~,接着按顺序正...
- liunx服务器查看故障命令有那些?_linux查看服务器性能命令
-
在Linux服务器上排查故障时,需要使用一系列命令来检查系统状态、日志文件、资源利用情况以及网络状况。以下是常用的故障排查命令,按照不同场景分类说明。1.系统资源相关命令1.1查看CPU使...
- 服务器被入侵的常见迹象有哪些?_服务器入侵可以被完全操纵吗
-
服务器被入侵可能会导致数据泄露、服务异常或完全失控。及时发现入侵迹象能够帮助你尽早采取措施,减少损失。以下是服务器被入侵的常见迹象以及相关的分析与处理建议。1.服务器被入侵的常见迹象1.1系统性能...
- 前端错误可观测最佳实践_前端错误提示
-
场景解析对于前端项目,生产环境的代码通常经过压缩、混淆和打包处理,当代码在运行过程中产生错误时,通常难以还原原始代码从而定位问题,对于深度混淆尤其如此,因此Mozilla自2011年开始发起并...
- 8个能让你的Kubernetes集群“瞬间崩溃”的配置错误
-
错误一:livenessProbe探针“自杀式”配置——30秒内让Pod重启20次现象:Pod状态在Running→Terminating→CrashLoopBackOff之间循环,重启间隔仅...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- apisix动态修改路由的原理_动态路由协议rip的配置
- 使用 Docker 部署 OpenResty Manager 搭建可视化反向代理系统
- OpenResty 入门指南:从基础到动态路由实战
- OpenResty 的 Lua 动态能力_openresty 动态upstream
- LVS和Nginx_lvs和nginx的区别
- 外网连接到内网服务器需要端口映射吗,如何操作?
- Nginx如何解决C10K问题(1万个并发连接)?
- 炸场!Spring Boot 9 大内置过滤器实战手册:从坑到神
- WordPress和Typecho xmlrpc漏洞_wordpress主题漏洞
- Linux Shell 入门教程(六):重定向、管道与命令替换
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)