独家 | 使用TensorFlow 2创建自定义损失函数
off999 2024-09-14 07:11 42 浏览 0 评论
本文带你学习使用Python中的wrapper函数和OOP来编写自定义损失函数。
图1:梯度下降算法
神经网络利用训练数据,将一组输入映射成一组输出,它通过使用某种形式的优化算法,如梯度下降、随机梯度下降、AdaGrad、AdaDelta等等来实现,其中最新的算法包括Adam、Nadam或RMSProp。梯度下降中的“梯度”是指误差梯度。每次迭代之后,网络将其预测输出与实际输出进行比较,然后计算出“误差”。
通常,对于神经网络,寻求的是将误差最小化。将误差最小化的目标函数通常称之为成本函数或损失函数,由“损失函数”计算出的值称为“损失”。在各种问题中使用的典型损失函数有:
均方误差;
均方对数误差;
二元交叉熵;
分类交叉熵;
稀疏分类交叉熵。
Tensorflow已经包含了上述损失函数,直接调用它们即可,如下所示:
1. 将损失函数当作字符串进行调用
model.compile (loss = ‘binary_crossentropy’,optimizer = ‘adam’, metrics = [‘accuracy’])
2. 将损失函数当作对象进行调用
from tensorflow.keras.losses importmean_squared_errormodel.compile(loss = mean_squared_error,optimizer=’sgd’)
将损失函数当作对象进行调用的优点是可以在损失函数中传递阈值等参数。
from tensorflow.keras.losses import mean_squared_errormodel.compile (loss=mean_squared_error(param=value),optimizer = ‘sgd’)
利用现有函数创建自定义损失函数:
利用现有函数创建损失函数,首先需要定义损失函数,它将接受两个参数,y_true(真实标签/输出)和y_pred(预测标签/输出)。
def loss_function(y_true, y_pred):***some calculation***return loss
创建均方误差损失函数 (RMSE):
定义损失函数名称-my_rmse。目的是返回目标(y_true)与预测(y_pred)之间的均方误差。
RMSE的公式为:
误差:真实标签与预测标签之间的差异。
sqr_error:误差的平方。
mean_sqr_error:误差平方的均值。
sqrt_mean_sqr_error:误差平方均值的平方根(均方根误差)。
创建Huber损失函数:
图2:Huber损失函数(绿色)和平方误差损失函数(蓝色)(来源:Qwertyus— Own work,CCBY-SA4.0,https://commons.wikimedia.org/w/index.php?curid=34836380)
Huber损失函数的计算公式:
在此处,δ是阈值,a是误差(将计算出a,即实际标签和预测标签之间的差异)。
当|a|≤δ时,loss = 1/2*(a)2
当 |a|>δ时,loss = δ(|a|—(1/2)*δ)
源代码:
详细说明:
首先,定义一个函数—— my huber loss,它需要两个参数:y_true和y_pred,
设置阈值threshold = 1。
计算误差error a = y_true-y_pred。接下来,检查误差的绝对值是否小于或等于阈值,is_small_error返回一个布尔值(真或假)。
当|a|≤δ时,loss= 1/2*(a)2,计算small_error_loss, 误差的平方除以2。否则,当|a| >δ时,则损失等于δ(|a|-(1/2)*δ),用big_error_loss来计算这个值。
最后,在返回语句中,首先检查is_small_error是真还是假,如果它为真,函数返回small_error_loss,否则返回big_error_loss,使用tf.where来实现。
可以使用下述代码来编译模型:
在上述代码中,将阈值设为1。
如果需要调整超参数(阈值),并在编译过程中加入一个新的阈值的话,必须使用wrapper函数进行封装,也就是说,将损失函数封装成另一个外部函数。在这里需要用到封装函数(wrapper function),因为损失函数在默认情况下只能接受y_true和y_pred值,而且不能向原始损失函数添加任何其他参数。
使用封装后的Huber损失函数
封装函数的源代码:
此时,阈值不是硬编码,可以在模型编译过程中传递该阈值。
使用类实现Huber损失函数(OOP)
其中,MyHuberLoss是类名称,随后从tensorflow.keras.losses继承父类“Loss”, MyHuberLoss继承了Loss类,之后可以将MyHuberLoss当作损失函数来使用。
__init__ 初始化该类中的对象。执行类实例化对象时调用函数,init函数返回阈值,调用函数得到y_true和y_pred参数,将阈值声明为一个类变量,可以给它赋一个初始值。
在__init__函数中,将阈值设置为self.threshold。在调用函数中,self.threshold引用所有的阈值类变量。在model.compile中使用这个损失函数:
创建对比性损失(用于Siamese网络):
Siamese网络可以用来比较两幅图像是否相似,Siamese网络使用的损失函数为对比性损失。
在上文的公式中,Y_true是关于图像相似性细节的张量,如果图像相似,则为1,如果图像不相似,则为0。
D是图像对之间的欧氏距离的张量。边际为一个常量,用它来设置将图像区别为相似或不同的最小距离。如果为Y_true=1,则方程的第一部分为D2,第二部分为0,所以,当Y_true接近1时,D2的权重则更重。
如果Y_true=0,则方程的第一部分变为0,第二部分会产生一些结果,这给了最大项更多的权重,给了D平方项更少的权重,此时,最大项在损失计算中占了优势。
使用封装器函数实现对比损失函数:
结论
在Tensorflow中没有的损失函数都可以利用函数、包装函数或类似的类来创建。
原文标题:
Creating custom Loss functionsusing TensorFlow 2
原文链接:
https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c
编辑:黄继彦
校对:林亦霖
相关推荐
- 系统大全网站(系统大全网站推荐)
-
下载时发生错误可能是以下原因:1.你的网速过慢,网页代码没有完全下载就运行了,导致不完整,当然就错误了。请刷新。2.网页设计错误,导致部分代码不能执行。请下载最新的遨游浏览器。3.你的浏览器不兼容导致...
- win10官方启动盘(win10官方启动盘怎么用)
-
1、在开始菜单搜索“设置”,打开“设置”;2、点击“更新与安全”,在左侧菜单栏点击“恢复”;3、点击“启动项”,在弹出的窗口中会显示当前可以启动的项目,点击“编辑”;4、在打开的“编辑启动项”窗口中,...
- win10系统安装不了(win10 安装不了)
-
电脑装不上win10系统可能是因为以下几个原因导致的原因一:win10安装文件不对我们在安装win10之前,要确保下载到安装包真实可用的,否则安装肯定会有问题,建议下载安全可靠的安装包!原因二:系统文...
- 国内dns哪个最快(dns开启好还是关闭好)
-
移动dns设置首选114.114.114.114,它又好又快。首选DNS和备用DNS都是一种域名系统,这两种域名系统有着先后之分,如果在首选DNS正常的情况下,就用首选DNS地址。当首选DNS服务器出...
- winxp安装盘(winxp系统安装)
-
xp系统安装步骤如下1、将下载的xp系统iso压缩包文件下载到C盘之外的分区,比如下载到D盘,右键使用WinRAR等工具解压到当前文件夹或指定文件夹,不能解压到C盘和桌面,否则无法安装;?2、解压之后...
- 现在的win11稳定了吗(win11稳定嘛)
-
windows10更稳定,由于win11刚刚推出没多久,稳定差不够好,兼容性也有待提升,无论是应用还是游戏都会遇到不明程度的问题,因此,在日常的使用过程中,我们还是应当以稳定性为优先,选择win10是...
- xp安装包下载到手机(xp系统安装包)
-
手机是基于ARM架构的处理器,而WindowsXP是基于x86架构的操作系统,因此无法直接在手机上安装WindowsXP。除非您的手机是使用Intel处理器,但这种情况非常罕见。如果您需要在手机上...
- 如何查看硬盘序列号(windows如何查看硬盘序列号)
-
1.打开开始菜单栏,输入【cmd】点击【确定】;2.在命令窗口依次输入【diskpart】-【listdisk】-【selectdisk0】;3.选好要查看的硬盘后,接着输入【detaildi...
- 虚拟机安装win7教程(虚拟机安装win7教程图解)
-
1.首先,下载并安装虚拟机软件,如VMwareWorkstation、VirtualBox等。2.打开虚拟机软件,创建一个新的虚拟机。3.在创建虚拟机的过程中,选择安装Windows7专业版的IS...
- 系统脱敏法的操作程序如何
-
系统脱敏疗法(systematicdesensitization)又称交互抑制法,是由美国学者沃尔普创立和发展的。这种方法主要是诱导求治者缓慢地暴露出导致神经症焦虑、恐惧的情境,并通过心理的放松状态...
- 闪迪u盘低级格式化工具(闪迪u盘格式化分配单元大小)
-
闪迪U盘格式化后速度变慢的可能原因及解决方法如下:文件系统问题:格式化时选择的文件系统类型可能会影响U盘的性能。常见的文件系统类型包括FAT32、NTFS和exFAT等。如果文件系统类型不合适,可能会...
- psd文件下载(psd格式下载网站)
-
1、在photoshop中,不能通过置入的方法来加载PSD文件,因为,通过置入的方法加载PSD文件,它是以合并图层的方法把PSD文件加入,这样,就失去了PSD文件的所有图层信息。 2、在文档中想...
- 宏碁官网下载win7系统(宏碁官方系统)
-
宏基笔记本win8系统换成win7步骤:1、更改bios设置,关闭“SecureBoot”功能,启用传统的“LegacyBoot”。2、制作u启动U盘启动盘,下载win7系统安装包3、设置U盘启动...
- 如何重装系统win7旗舰版32位
-
首先下载制作一个带系统的启动u盘,然后按以下步骤安装:1、首先关闭电脑上面的杀毒软件,2、进入bios选择u盘启动。3、插入启动u盘重新启动电脑4、进入pe系统镜像环节,选择要安装的系统(32位),然...
- 应用程序发生异常0xe0000008
-
先查看一下对应的软件是不是出现了损坏,也可以重装此软件。我们还可以尝试通过修改注册表来解决。按Win+R(或者在开始菜单搜索框输入“运行”)打开运行,然后输入“regedit”回车,打开注册表恢复原来...
欢迎 你 发表评论:
- 一周热门
-
-
抖音上好看的小姐姐,Python给你都下载了
-
全网最简单易懂!495页Python漫画教程,高清PDF版免费下载
-
Python 3.14 的 UUIDv6/v7/v8 上新,别再用 uuid4 () 啦!
-
飞牛NAS部署TVGate Docker项目,实现内网一键转发、代理、jx
-
python入门到脱坑 输入与输出—str()函数
-
宝塔面板如何添加免费waf防火墙?(宝塔面板开启https)
-
Python三目运算基础与进阶_python三目运算符判断三个变量
-
(新版)Python 分布式爬虫与 JS 逆向进阶实战吾爱分享
-
失业程序员复习python笔记——条件与循环
-
系统u盘安装(win11系统u盘安装)
-
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)
