Python机器学习系列之scikit-learn决策树分类问题简单实践
off999 2024-11-26 07:24 17 浏览 0 评论
1.决策树分类问题实践
在前面的一个章节中,我们简要地概述了一下决策树的原理知识,了解了一下决策树分支原理,调参过程以及可视化决策树:
在决策树DecisionTree中,DecisionTreeClassifier是能够处理一些分类问题的。与其他分类器一样,DecisionTreeClassifier将两个数组作为输入:一个数组X,稀疏或者密集,shape (n_samples, n_features),以及一个整数值数组Y,shape (n_samples,)
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
拟合后,该模型可用于预测样本类别:
clf.predict([[2., 2.]])
array([1])
如果存在多个具有相同且最高概率的类,分类器将预测这些类中具有最低索引的类。
作为输出特点类的替代方法,可以预测每个类的概率,即该类在叶子节点中的训练样本的分数:
clf.predict_proba([[2., 2.]])
array([[0., 1.]])
DecisionTreeClassifier能够进行二元(其中标签为[-1,1])分类和多分类(其中标签为[0,.....,k-1])分类。
下面通过一个简单的官方示例来了解下,绘制根据iris数据集的一对特征进行训练的决策树的决策面。
对于每一对iris数据集特征,决策树学习由训练样本推断的简单阈值规则组合构成的决策边界。
1.1 加载数据集
首先加载scikit-learn附带的iris数据集样本:
from sklearn.datasets import load_iris
iris = load_iris()
通过数据集的属性简单探索下iris数据集的结构:
iris.data.shape
(150, 4)
通过shape属性我们可以清楚地看到,iris数据集是一个二维矩阵,数据集中有150个样本,每一个样本包含4个特征属性;
iris.target.shape
(150,)
其实,标签数据是一个一维矩阵,标签中也有150个样本;从这里能够充分说明iris数据集中,特征的样本数与标签的样本数是一样的,都是150个样本。
import numpy as np
np.unique(iris.target)
array([0, 1, 2])
iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
通过以上代码,我们可以验证标签数据集中包含3个类别,分别是:setosa,versicolor和virginica。
import pandas as pd
csv = pd.concat([pd.DataFrame(iris.data, columns=iris.feature_names), pd.DataFrame(iris.target, columns=['target'])], axis=1)
csv
csv.head()
通过pandas将数据集转化成csv格式,调用head()函数,我们可以清楚地看到iris数据集的全貌,经过简单的探索,我们可以清楚地发现,iris数据集是一个包含150个样本数据,4个特征向量,以及标签包含150个样本和3个分类的数据集。
1.2建立模型
下面开始建立决策树模型,导入决策树模块和一些辅助的模块:
# 处理数组和矩阵的模块
import numpy as np
# 处理画图的模块
import matplotlib.pyplot as plt
# 决策树模块
from sklearn.tree import DecisionTreeClassifier
# 定义所需要的参数
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02
# 设置画布大小
plt.figure(figsize=(12, 6))
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
X = iris.data[:, pair]
y = iris.target
# 训练模型
clf = DecisionTreeClassifier().fit(X, y)
# 画图
plt.subplot(2, 3, pairidx + 1)
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(
np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step)
)
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
# 预测
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
plt.xlabel(iris.feature_names[pair[0]])
plt.ylabel(iris.feature_names[pair[1]])
for i, color in zip(range(n_classes), plot_colors):
idx = np.where(y == i)
plt.scatter(
X[idx, 0],
X[idx, 1],
c=color,
label=iris.target_names[i],
cmap=plt.cm.RdYlBu,
edgecolor="black",
s=15,
)
plt.suptitle("Decision surface of decision trees trained on pairs of features")
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
_ = plt.axis("tight")
1.3运行效果
上面的代码在jupyter lab里运行如下:
绘制的决策树决策边界如上图所示,我们可以简单总结如下:
- iris数据集在决策树上经过训练(fit)和预测(predict)后,绘制出3个决策面,因为iris数据集中包括3个分类;
- 3种不同的颜色对应3种不同的分类,通过3个不同的决策面把iris数据集划分成3个不同的类别;
- 仔细观察绘制的图例,我们可以发现大部分相同颜色的点都落在了相同颜色的面,红色点落在红色面中,蓝色点落在蓝色面中;这表示这部分样本被正确的分类了,也就是说大部分样本都被正确地分类了;
- 每个分类的决策面都存在一条清晰明确的分割线,把3个类别划分开;
- 当我们放大图例的图片时,我们可以看到有极个别的点落在了其他分类的决策面中,比如:有几个红色的点落在了蓝色的决策面中,这说明这部分少数的点被误分类了;
- 大部分点都落在了正确的决策面,少数部分点落在了其他决策面,这充分说明决策树DecisionTreeClassifier分类的准确率不是100%正确分类的,存在一定比例的错误分类;
- 少数点落在其他决策面的问题就是错误分类的问题;
1.4绘制决策树结构
下面通过简单的几行代码来可视化下iris数据集上,决策树的形状:
from sklearn.tree import plot_tree
plt.figure(figsize=(16, 9))
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf,
filled=True,
rounded=True,
feature_names=iris.feature_names,
class_names=['setosa', 'versicolor', 'virginica'])
plt.title("Decision tree trained on all the iris features")
plt.savefig('DecisionTreeClassifier.jpg')
plt.show()
通过scikit-learn模块的plot_tree函数,我们可以轻松地绘制一棵决策树的树结构:
- 决策树有一个根节点,许多父节点和叶子节点;
- 每一个节点,不管是根节点,父节点还是叶子节点,都包含许多信息,例如:gini系数;因为决策树默认的分枝方法“不纯度”采用的是gini系数;
- 这棵决策树的节点不同颜色代表不同的分类类别,当节点的颜色越深的时候,该节点的gini系数也是最低的;
- 决策树的所有叶子节点的gini系数都为0,也就是说不纯度为0的时候我们就可以选择出标签的一个类别了;
1.5探索决策树结构
我们详细地探索下每一个节点都包含哪些数据,都有哪些含义:
- 每棵树节点都有5行数据,每行数据代表的意义是不一样的;
- 第一行数据表示对特征点进行提问,该节点分枝的左节点是对这个提问的YES回答,该节点分枝的右节点是对这个提问的NO回答;
- 第二行数据是不纯度指标,默认是gini系数的值;通过这棵决策树我们不难发现整个决策树从根节点到叶子节点gini系数是不断变小的;也就是说不纯度是降低的。
- 不纯度基于节点来计算,树中的每个节点都会有一个不纯度,并且子节点的不纯度一定是低于父节点的,也就是说,在同一棵决策树上,叶子节点的不纯度一定是最低的。
- 第三行数据samples表示该节点包含的样本数量;
- 第四行数据value表示标签的每一个类别所占的数量,因为iris数据集中标签包含3个类别,所以value是包含3个数值的数组;
- 第五行数据class表示该节点所在的分类类别;
我们再深入的探索下,在这棵决策树中哪些特征指标对这棵树起决定性作用呢?我们可以使用决策树的关键属性feature_importances_来看一下,哪些特征对这样一棵决策树的贡献最大。
clf.feature_importances_
[*zip(iris.feature_names,clf.feature_importances_)]
我们来简单分析下,iris数据集中一个样本包含4个特征数据,其中“petal length”花瓣的长度对这棵决策树的贡献是最大的,因为它的数值最大,数值越大对决策树的贡献也就越大;“sepal width”萼片宽度对这棵决策树的贡献最小,它的值为0。
1.6决策树的准确率
下面我们通过简单的几行代码来了解下决策树分类的准确率问题
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(iris.data,iris.target,test_size=0.3)
clf = DecisionTreeClassifier()
clf = clf.fit(Xtrain, Ytrain)
#返回预测的准确度accuracy
score = clf.score(Xtest, Ytest)
score
运行效果如下:
可以通过决策树DecisionTreeClassifier的score函数获取决策树分类的准确率,通过以上代码我们可以非常容易的获取一棵默认参数的决策树在iris数据集上分类的准确率为95%以上;
在这里不得不提的是,所有接口中要求输入X_train和X_test的部分,输入的特征矩阵必须至少是一个二维矩阵。sklearn不接受任何一维矩阵作为特征矩阵被输入。如果你的数据的确只有一个特征,那必须用reshape(-1,1)来给矩阵增维;如果你的数据只有一个特征和一个样本,使用reshape(1,-1)来给你的数据增维。
1.7决策树随机性
random_state用来设置分枝中的随机模式的参数,默认None,在高维度时随机性会表现更明显,低维度的数据,随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。
当我们反复运行上面决策树准确率的代码的时候,输出的score会每运行一次数值就会变化一次,这是因为我们建立的决策树模型没有设定任何参数,都是采用决策树默认的参数,那么当我们给决策树设定一个random_state看看效果如何?
clf = DecisionTreeClassifier(random_state=10)
clf = clf.fit(Xtrain, Ytrain)
#返回预测的准确度accuracy
score = clf.score(Xtest, Ytest)
score
当我们加上random_state参数后,无论运行多少次上面的代码,score的输出值就已经能够确定下来了,也就是说score的值不会再改变了;random_state的取值可以是任意整数值,它只是代表决策树的随机性能够被确定下来。
1.8决策树剪枝参数
在不加限制的情况下,一棵决策树会生长到衡量不纯度的指标最优,或者没有更多的特征可用为止。这样的决策树往往会过拟合,这就是说,它会在训练集上表现很好,但在测试集上却表现糟糕。我们收集的样本数据不可能和整体的状况完全一致,因此当一棵决策树对训练数据有了过于优秀的解释性,它找出的规则必然包含了训练样本中的噪声,并使它对未知数据的拟合程度不足。
clf = DecisionTreeClassifier(random_state=30
,max_depth=3
,min_samples_leaf=10
,min_samples_split=10
)
clf = clf.fit(iris.data, iris.target)
plt.figure(figsize=(16, 9))
plot_tree(clf,
filled=True,
rounded=True,
feature_names=iris.feature_names,
class_names=['setosa', 'versicolor', 'virginica'])
plt.savefig('cut_DecisionTreeClassifier.jpg')
plt.show()
剪枝后的决策树运行效果:
通过运行后生成的剪枝决策树结构与上面没有剪枝的决策树结构进行对比,我们可以清楚地发现,剪枝后的决策树的树结构更加扁平化了,树的深度降低了。具体参数解释如下:
- random_state用来设置分枝中的随机模式的参数,默认None;
- max_depth限制树的最大深度,超过设定深度的树枝全部剪掉;
- min_samples_leaf和min_samples_split表示min_samples_leaf限定一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生;
1.9确认最优的剪枝参数
那具体怎么来确定每个参数填写什么值呢?这时候,我们就要使用确定超参数的曲线来进行判断了,继续使用我们已经训练好的决策树模型clf。超参数的学习曲线,是一条以超参数的取值为横坐标,模型的度量指标为纵坐标的曲线,它是用来衡量不同超参数取值下模型的表现的线。在我们建好的决策树里,我们的模型度量指标就是score。
scores=[]
for i in range(10):
clf = DecisionTreeClassifier(max_depth=i+1,
random_state=10)
clf = clf.fit(Xtrain, Ytrain)
score= clf.score(Xtest, Ytest)
scores.append(score)
plt.plot(range(1,11),scores,color="red",label="max_depth")
plt.grid()
plt.legend()
plt.savefig('max_depth.jpg')
plt.show()
运行效果图:
通过上面的学习曲线,我们可以很容易地发现,当max_depth为3的时候,我们的这个决策树模型的准确率已经保持不变了,最大准确率保持在97%以上;也就是说这棵决策树在保持其他默认参数的情况下,树的最大深度为3的时候,这棵决策树的分类效果是最后的,分类准确率最高。
无论如何,剪枝参数的默认值会让树无尽地生长,这些树在某些数据集上可能非常巨大,对内存的消耗也非常巨大。所以如果你手中的数据集非常巨大,你已经预测到无论如何你都是要剪枝的,那提前设定这些参数来控制树的复杂性和大小会比较好。
至此,我们已经学完了决策树DecisionTreeClassifier和用决策树绘图(plot_tree)的所有基础。
不积跬步,无以至千里;
不积小流,无以成江海;
参考资料:
https://scikit-learn.org/stable/modules/tree.html
- 上一篇:python决策树-2
- 下一篇:使用Scikit-Learn了解决策树分类
相关推荐
- 面试官:来,讲一下枚举类型在开发时中实际应用场景!
-
一.基本介绍枚举是JDK1.5新增的数据类型,使用枚举我们可以很好的描述一些特定的业务场景,比如一年中的春、夏、秋、冬,还有每周的周一到周天,还有各种颜色,以及可以用它来描述一些状态信息,比如错...
- 一日一技:11个基本Python技巧和窍门
-
1.两个数字的交换.x,y=10,20print(x,y)x,y=y,xprint(x,y)输出:102020102.Python字符串取反a="Ge...
- Python Enum 技巧,让代码更简洁、更安全、更易维护
-
如果你是一名Python开发人员,你很可能使用过enum.Enum来创建可读性和可维护性代码。今天发现一个强大的技巧,可以让Enum的境界更进一层,这个技巧不仅能提高可读性,还能以最小的代价增...
- Python元组编程指导教程(python元组的概念)
-
1.元组基础概念1.1什么是元组元组(Tuple)是Python中一种不可变的序列类型,用于存储多个有序的元素。元组与列表(list)类似,但元组一旦创建就不能修改(不可变),这使得元组在某些场景...
- 你可能不知道的实用 Python 功能(python有哪些用)
-
1.超越文件处理的内容管理器大多数开发人员都熟悉使用with语句进行文件操作:withopen('file.txt','r')asfile:co...
- Python 2至3.13新特性总结(python 3.10新特性)
-
以下是Python2到Python3.13的主要新特性总结,按版本分类整理:Python2到Python3的重大变化Python3是一个不向后兼容的版本,主要改进包括:pri...
- Python中for循环访问索引值的方法
-
技术背景在Python编程中,我们经常需要在循环中访问元素的索引值。例如,在处理列表、元组等可迭代对象时,除了要获取元素本身,还需要知道元素的位置。Python提供了多种方式来实现这一需求,下面将详细...
- Python enumerate核心应用解析:索引遍历的高效实践方案
-
喜欢的条友记得关注、点赞、转发、收藏,你们的支持就是我最大的动力源泉。根据GitHub代码分析统计,使用enumerate替代range(len())写法可减少38%的索引错误概率。本文通过12个生产...
- Python入门到脱坑经典案例—列表去重
-
列表去重是Python编程中常见的操作,下面我将介绍多种实现列表去重的方法,从基础到进阶,帮助初学者全面掌握这一技能。方法一:使用集合(set)去重(最简单)pythondefremove_dupl...
- Python枚举类工程实践:常量管理的标准化解决方案
-
本文通过7个生产案例,系统解析枚举类在工程实践中的应用,覆盖状态管理、配置选项、错误代码等场景,适用于Web服务开发、自动化测试及系统集成领域。一、基础概念与语法演进1.1传统常量与枚举类对比#传...
- 让Python枚举更强大!教你玩转Enum扩展
-
为什么你需要关注Enum?在日常开发中,你是否经常遇到这样的代码?ifstatus==1:print("开始处理")elifstatus==2:pri...
- Python枚举(Enum)技巧,你值得了解
-
枚举(Enum)提供了更清晰、结构化的方式来定义常量。通过为枚举添加行为、自动分配值和存储额外数据,可以提升代码的可读性、可维护性,并与数据库结合使用时,使用字符串代替数字能简化调试和查询。Pytho...
- 78行Python代码帮你复现微信撤回消息!
-
来源:悟空智能科技本文约700字,建议阅读5分钟。本文基于python的微信开源库itchat,教你如何收集私聊撤回的信息。[导读]Python曾经对我说:"时日不多,赶紧用Python"。于是看...
- 登录人人都是产品经理即可获得以下权益
-
文章介绍如何利用Cursor自动开发Playwright网页自动化脚本,实现从选题、写文、生图的全流程自动化,并将其打包成API供工作流调用,提高工作效率。虽然我前面文章介绍了很多AI工作流,但它们...
- Python常用小知识-第二弹(python常用方法总结)
-
一、Python中使用JsonPath提取字典中的值JsonPath是解析Json字符串用的,如果有一个多层嵌套的复杂字典,想要根据key和下标来批量提取value,这是比较困难的,使用jsonpat...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python自定义函数 (53)
- python进度条 (67)
- python吧 (67)
- python字典遍历 (54)
- python的for循环 (65)
- python格式化字符串 (61)
- python串口编程 (60)
- python读取文件夹下所有文件 (59)
- java调用python脚本 (56)
- python操作mysql数据库 (66)
- python字典增加键值对 (53)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python人脸识别 (54)
- python多态 (60)
- python命令行参数 (53)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)