资深大佬教你如何利用PyTorch实现图像识别(图文详解)
off999 2025-05-30 16:55 70 浏览 0 评论
这篇文章主要给大家介绍了关于如何利用PyTorch实现图像识别的相关资料,文中通过图文以及实例代码介绍的非常详细,对大家学习或者使用PyTorch具有一定的参考学习价值,需要的朋友可以参考下
目录
- 使用torchvision库的datasets类加载常用的数据集或自定义数据集
- 使用torchvision库进行数据增强和变换,自定义自己的图像分类数据集并使用torchvision库加载它们
- 使用torchvision库的models类加载预训练模型或自定义模型
- forward方法
- 总结
使用torchvision库的datasets类加载常用的数据集或自定义数据集
图像识别是计算机视觉中的一个基础任务,它的目标是让计算机能够识别图像中的物体、场景或者概念,并将它们分配到预定义的类别中。例如,给定一张猫的图片,图像识别系统应该能够输出“猫”这个类别。
为了训练和评估图像识别系统,我们需要有大量的带有标注的图像数据集。常用的图像分类数据集有:
- ImageNet:一个包含超过1400万张图片和2万多个类别的大型数据库,是目前最流行和最具挑战性的图像分类基准之一。
- CIFAR-10/CIFAR-100:一个包含6万张32×32大小的彩色图片和10或100个类别的小型数据库,适合入门级和快速实验。
- MNIST:一个包含7万张28×28大小的灰度手写数字图片和10个类别的经典数据库,是深度学习中最常用的测试集之一。
- Fashion-MNIST:一个包含7万张28×28大小的灰度服装图片和10个类别的数据库,是MNIST数据库在时尚领域上更加复杂和现代化版本。
使用torchvision库可以方便地加载这些常用数据集或者自定义数据集。torchvision.datasets提供了一些加载数据集或者下载数据集到本地缓存文件夹(默认为./data)并返回Dataset对象(torch.utils.data.Dataset) 的函数。Dataset对象可以存储样本及其对应标签,并提供索引方式(dataset[i])来获取第i个样本。例如,要加载CIFAR-10训练集并进行随机打乱,可以使用以下代码:
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor()]) # 定义转换函数,将PIL.Image转换为torch.Tensor
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) # 加载CIFAR-10训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True) # 定义DataLoader对象,用于批量加载数据使用torchvision库进行数据增强和变换,自定义自己的图像分类数据集并使用torchvision库加载它们
数据增强和变换:为了提高模型的泛化能力和数据利用率,我们通常会对图像数据进行一些随机的变换,例如裁剪、旋转、翻转、缩放、亮度调整等。这些变换可以在一定程度上模拟真实场景中的图像变化,增加模型对不同视角和光照条件下的物体识别能力。torchvision.transforms提供了一些常用的图像变换函数,可以组合成一个transform对象,并传入datasets类中作为参数。例如,要对CIFAR-10训练集进行随机水平翻转和随机裁剪,并将图像归一化到[-1, 1]范围内,可以使用以下代码:
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪到32×32大小,并在边缘填充4个像素
transforms.ToTensor(), # 将PIL.Image转换为torch.Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将RGB三个通道的值归一化到[-1, 1]范围内
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) # 加载CIFAR-10训练集,并应用上述变换
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True) # 定义DataLoader对象,用于批量加载数据自定义图像分类数据集:如果我们有自己的图像分类数据集,我们可以通过继承torch.utils.data.Dataset类来自定义一个Dataset对象,并实现__len__和__getitem__两个方法。__len__方法返回数据集中样本的数量,__getitem__方法根据给定的索引返回一个样本及其标签。例如,假设我们有一个文件夹结构如下:
my_dataset/
├── class_0/
│ ├── image_000.jpg
│ ├── image_001.jpg
│ └── ...
├── class_1/
│ ├── image_000.jpg
│ ├── image_001.jpg
│ └── ...
└── ...
其中每个子文件夹代表一个类别,每个子文件夹中包含该类别对应的图像文件。我们可以使用以下代码来自定义一个Dataset对象,并加载这个数据集:
import torch.utils.data as data
from PIL import Image
import os
class MyDataset(data.Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir # 根目录路径
self.transform = transform # 变换函数
self.classes = sorted(os.listdir(root_dir)) # 类别列表(按字母顺序排序)
self.class_to_idx = {c: i for i,c in enumerate(self.classes)} # 类别名到索引的映射
self.images = [] # 图片路径列表(相对于根目录)
self.labels = [] # 标签列表(整数)
for c in self.classes:
c_dir = os.path.join(root_dir, c) # 类别子目录路径
for img_name in sorted(os.listdir(c_dir)): # 遍历每个图片文件名(按字母顺序排序)
img_path = os.path.join(c,img_name) # 图片相对路径(相对于根目录)
label = self.class_to_idx[c] # 图使用torchvision库的models类加载预训练模型或自定义模型
加载预训练模型或自定义模型:torchvision.models提供了一些常用的图像分类模型,例如AlexNet、VGG、ResNet等,并且可以选择是否加载在ImageNet数据集上预训练好的权重。这些模型可以直接用于图像分类任务,也可以作为特征提取器或者微调(fine-tune)的基础。例如,要加载一个预训练好的ResNet-18模型,并冻结除最后一层外的所有参数,可以使用以下代码:
import torchvision.models as models
model = models.resnet18(pretrained=True) # 加载预训练好的ResNet-18模型
for param in model.parameters(): # 遍历所有参数
param.requires_grad = False # 将参数的梯度设置为False,表示不需要更新
num_features = model.fc.in_features # 获取全连接层(fc)的输入特征数
model.fc = torch.nn.Linear(num_features, 10) # 替换全连接层为一个新的线性层,输出特征数为10(假设有10个类别)如果我们想要自定义自己的图像分类模型,我们可以通过继承torch.nn.Module类来实现一个Module对象,并实现__init__和forward两个方法。__init__方法用于定义模型中需要的各种层和参数,forward方法用于定义前向传播过程。例如,要自定义一个简单的卷积神经网络(CNN)模型,可以使用以下代码:
import torch.nn as nn
class MyCNN(nn.Module):
def __init__(self):
super(MyCNN, self).__init__() # 调用父类构造函数
self.conv1 = nn.Conv2d(3, 6, 5) # 定义第一个卷积层,输入通道数为3(RGB),输出通道数为6,卷积核大小为5×5
self.pool = nn.MaxPool2d(2, 2) # 定义最大池化层,池化核大小为2×2,步长为2
self.conv2 = nn.Conv2d(6, 16, 5) # 定义第二个卷积层,输入通道数为6,输出通道数为16,卷积核大小为5×5
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 定义第一个全连接层,输入特征数为16×5×5(根据卷积和池化后的图像大小计算得到),输出特征数为120
self.fc2 = nn.Linear(120, 84) # 定义第二个全连接层,输入特征数为120,输出特征数为84
self.fc3 = nn.Linear(84, 10) # 定义第三个全连接层,输入特征数为84,forward方法
forward方法用于定义前向传播过程,即如何根据输入的图像张量(Tensor)计算出输出的类别概率分布。我们可以使用定义好的各种层和参数,并结合一些激活函数(如ReLU)和归一化函数(如softmax)来实现forward方法。例如,要实现上面自定义的CNN模型的forward方法,可以使用以下代码:
import torch.nn.functional as F
class MyCNN(nn.Module):
def __init__(self):
# 省略__init__方法的内容
...
def forward(self, x): # 定义前向传播过程,x是输入的图像张量
x = self.pool(F.relu(self.conv1(x))) # 将x通过第一个卷积层和ReLU激活函数,然后通过最大池化层
x = self.pool(F.relu(self.conv2(x))) # 将x通过第二个卷积层和ReLU激活函数,然后通过最大池化层
x = x.view(-1, 16 * 5 * 5) # 将x展平为一维向量,-1表示自动推断批量大小
x = F.relu(self.fc1(x)) # 将x通过第一个全连接层和ReLU激活函数
x = F.relu(self.fc2(x)) # 将x通过第二个全连接层和ReLU激活函数
x = self.fc3(x) # 将x通过第三个全连接层
x = F.softmax(x, dim=1) # 将x通过softmax函数,沿着第一个维度(类别维度)进行归一化,得到类别概率分布
return x # 返回输出的类别概率分布进行模型训练和测试,使用matplotlib.pyplot库可视化结果
模型训练和测试是机器学习中的重要步骤,它们可以帮助我们评估模型的性能和泛化能力。matplotlib.pyplot是一个Python库,它可以用来绘制各种类型的图形,包括曲线图、散点图、直方图等。使用matplotlib.pyplot库可视化结果的一般步骤如下:
- 导入matplotlib.pyplot模块,并设置一些参数,如字体、分辨率等。
- 创建一个或多个图形对象(figure),并指定大小、标题等属性。
- 在每个图形对象中创建一个或多个子图(subplot),并指定位置、坐标轴等属性。
- 在每个子图中绘制数据,使用不同的函数和参数,如plot、scatter、bar等。
- 添加一些修饰元素,如图例(legend)、标签(label)、标题(title)等。
- 保存或显示图形。
例如:使用matplotlib.pyplot库绘制了一个线性回归模型的训练误差和测试误差曲线:
# 导入模块
import matplotlib.pyplot as plt
import numpy as np
# 设置字体和分辨率
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
%config InlineBackend.figure_format = "retina"
# 生成数据
x = np.linspace(0, 10, 100)
y = 3 * x + 5 + np.random.randn(100) * 2 # 真实值
w = np.random.randn() # 随机初始化权重
b = np.random.randn() # 随机初始化偏置
# 定义损失函数
def loss(y_true, y_pred):
return ((y_true - y_pred) ** 2).mean()
# 定义梯度下降函数
def gradient_descent(x, y_true, w, b, lr):
y_pred = w * x + b # 预测值
dw = -2 * (x * (y_true - y_pred)).mean() # 权重梯度
db = -2 * (y_true - y_pred).mean() # 偏置梯度
w = w - lr * dw # 更新权重
b = b - lr * db # 更新偏置
return w, b
# 训练模型,并记录每轮的训练误差和测试误差
epochs = 20 # 训练轮数
lr = 0.01 # 学习率
train_loss_list = [] # 训练误差列表
test_loss_list = [] # 测试误差列表
for epoch in range(epochs):
# 划分训练集和测试集(8:2)
train_index = np.random.choice(100, size=80, replace=False)
test_index = np.setdiff1d(np.arange(100), train_index)
x_train, y_train = x[train_index], y[train_index]
x_test, y_test = x[test_index], y[test_index]
# 梯度下降更新参数,并计算训练误差和测试误差
w, b = gradient_descent(x_train, y_train, w, b, lr)
train_loss = loss(y_train, w * x_train + b)
test_loss = loss(y_test, w * x_test + b)
# 打印结果,并将误差添加到列表中
print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
train_loss_list.append(train_loss)
test_loss_list.append(test_loss)
# 创建一个图形对象,并设置大小为8*6英寸
plt.figure(figsize=(8,6))
# 在图形对象中创建一个子图,并设置位置为1行1列的第1个
plt.subplot(1, 1, 1)
# 在子图中绘制训练误差和测试误差曲线,使用不同的颜色和标签
plt.plot(np.arange(epochs), train_loss_list, "r", label="Train Loss")
plt.plot(np.arange(epochs), test_loss_list, "b", label="Test Loss")
# 添加图例、坐标轴标签和标题
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Linear Regression Loss Curve")
# 保存或显示图形
#plt.savefig("loss_curve.png")
plt.show()运行后,可以看到如下的图形:
参考:: PyTorch官方网站
总结
到此这篇关于如何利用PyTorch实现图像识别的文章就介绍到这了,更多相关PyTorch图像识别内容请搜索小编以前的文章或继续浏览下面的相关文章希望大家以后多多支持小编!
相关推荐
- usb系统盘下载(系统u盘之家)
-
手机不可以下载电脑系统到U盘里,这是跟系统文件的格式有直接关系。电脑的系统文件,它在下载安装的时候必须使用电脑版本的U盘才可以正确安装。手机的版本它和电脑的版本差别比较大,即使下载后也不可能正确安装。...
- windows8模拟器(国内版)(win8模拟器安卓版下载)
-
雷电模拟器能在win8系统运行,1、官网下载雷电模拟器,双击安装包进入安装界面。2、点击“自定义安装”修改安装路径,点击“浏览”选择好要安装的路径,默认勾选“已同意”,最后点击“立即安装”。...
- win10安装专业版还是家庭版(win10安装专业版还是家庭版好)
-
从Win10家庭版和专业版对比来看,Win10专业版要比家庭版功能更强大一些,不过价格更贵。另外Win10专业版的一系列Win10增强技术对于普通用户也基本用不到,多了也显得系统不那么精简,因此普通个...
- win10系统保护不见了(win10系统保护打不开怎么办)
-
1、启动计算机,启动到Windows10开机LOGO时就按住电源键强制关机,重复强制关机3次!2、重复步骤3次左右启动后出现“自动修复”界面,我们点击高级选项进入;3、接下来会到选择一个选项界面...
- 新手如何重装win8(怎么重新装系统win8)
-
要想重装回win8.1系统,首先你需要一个win8.1的系统安装盘,然后把你电脑的系统盘格式化一下,或者把你的win10系统删除了,再把win8.1系统安装盘插到电脑上,进行系统安装,等电脑安装系统完...
- 磁盘分区工具软件(硬盘分区工具软件)
-
如果说最安全的那就用电脑自带的吧,右键我的电脑,找到管理,然后进去磁盘管理,然后找到目前的一个磁盘,右键压缩卷,输入压缩空间就是你想要的一个盘的大小(1G=1024MB),然后压缩,然后找到你压缩出来...
- ftp手机客户端(ftp手机客户端存文件)
-
要想实现FTP文件传输,必须在相连的两端都装有支持FTP协议的软件,装在您的电脑上的叫FTP客户端软件,装在另一端服务器上的叫做FTP服务器端软件。 客户端FTP软件使用方法很简单,启动后首先要与...
- 原版xp系统镜像(原版xp系统镜像怎么设置)
-
msdnitellyou又可以上了,那里有。 制作需要的软件 在开始进行制作之前,我们首先需要下载几个软件,启动光盘制作工具:EasyBoot,UltraISO以及用来对制作好的ISO镜像进行测...
- office2007密钥 office2016(office2007ultimate密钥)
-
word2016激活密钥有两种类型:永久激活码和KMS期限激活密钥。其中,永久激活密钥可以使用批量授权版永久激活密钥进行激活,如所示;而KMS期限激活密钥需要使用KMS客户端密钥进行激活,如所示。另外...
- windows10系统启动盘制作(windows10启动盘制作教程)
-
Windows10系统更改启动磁盘的方法如下1、按快捷键Win+R,调出命令窗口2、输入msconfig,点【确定】3、在系统配置中,选择【引导】菜单4、选择要默认启动的磁盘,点【设置为默认值】,...
- 方正电脑怎么重装系统
-
购买一张系统盘,然后启动电脑,将购买的系统盘插入电脑光驱中,等待光驱读取系统盘后,点击安装系统,即可自动安装,等待安装完毕,电脑会自动重启,重新启动后,电脑的系统就安装完毕,可以使用了一、准备需要的软...
-
- qq邮箱怎么写才正确
-
步骤/方式1一般默认的QQ邮箱格式是:QQ号码@qq.com,即QQ账号+@qq.com后缀步骤/方式2若要发送邮件,也要在对方的qq帐号末尾加上@qq.com1.每个人在注册QQ时都会有关联的一个邮箱,它的格式就是“QQ号码@qq.com...
-
2025-12-21 18:51 off999
-
- 电脑怎么看配置信息
-
要查看Windows电脑的配置信息,可以通过按下Win键+R,然后在弹出的运行对话框中输入“dxdiag”并按回车键打开DirectX诊断工具,可以查看有关处理器、内存、显卡等硬件信息。另外,还可以右键点击“此电脑”,选择“属性”来查看...
-
2025-12-21 18:03 off999
- mpeg是什么格式(mpeg是什么格式和mp4的区别)
-
是视频格式,是电脑视频文件的一种,相对其它视频文件格式而言,mpeg格式占的存储空间相对比较小,那么像素也就相对比较低,图像也没有其它格式那么高清,不过一般情况下已经满足正常的使用。好多视频文件都是采...
欢迎 你 发表评论:
- 一周热门
-
-
抖音上好看的小姐姐,Python给你都下载了
-
全网最简单易懂!495页Python漫画教程,高清PDF版免费下载
-
Python 3.14 的 UUIDv6/v7/v8 上新,别再用 uuid4 () 啦!
-
飞牛NAS部署TVGate Docker项目,实现内网一键转发、代理、jx
-
python入门到脱坑 输入与输出—str()函数
-
宝塔面板如何添加免费waf防火墙?(宝塔面板开启https)
-
Python三目运算基础与进阶_python三目运算符判断三个变量
-
(新版)Python 分布式爬虫与 JS 逆向进阶实战吾爱分享
-
失业程序员复习python笔记——条件与循环
-
系统u盘安装(win11系统u盘安装)
-
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)
