百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术资源 > 正文

独家 | 使用TensorFlow 2创建自定义损失函数

off999 2024-09-14 07:11 28 浏览 0 评论

本文带你学习使用Python中的wrapper函数和OOP来编写自定义损失函数。


图1:梯度下降算法

神经网络利用训练数据,将一组输入映射成一组输出,它通过使用某种形式的优化算法,如梯度下降、随机梯度下降、AdaGrad、AdaDelta等等来实现,其中最新的算法包括Adam、Nadam或RMSProp。梯度下降中的“梯度”是指误差梯度。每次迭代之后,网络将其预测输出与实际输出进行比较,然后计算出“误差”。

通常,对于神经网络,寻求的是将误差最小化。将误差最小化的目标函数通常称之为成本函数或损失函数,由“损失函数”计算出的值称为“损失”。在各种问题中使用的典型损失函数有:


均方误差;

均方对数误差;

二元交叉熵;

分类交叉熵;

稀疏分类交叉熵。

Tensorflow已经包含了上述损失函数,直接调用它们即可,如下所示:

1. 将损失函数当作字符串进行调用

model.compile (loss = ‘binary_crossentropy’,optimizer = ‘adam’, metrics = [‘accuracy’])

2. 将损失函数当作对象进行调用

from tensorflow.keras.losses importmean_squared_errormodel.compile(loss = mean_squared_error,optimizer=’sgd’)

将损失函数当作对象进行调用的优点是可以在损失函数中传递阈值等参数。

from tensorflow.keras.losses import mean_squared_errormodel.compile (loss=mean_squared_error(param=value),optimizer = ‘sgd’)

利用现有函数创建自定义损失函数:

利用现有函数创建损失函数,首先需要定义损失函数,它将接受两个参数,y_true(真实标签/输出)和y_pred(预测标签/输出)。

def loss_function(y_true, y_pred):***some calculation***return loss

创建均方误差损失函数 (RMSE):

定义损失函数名称-my_rmse。目的是返回目标(y_true)与预测(y_pred)之间的均方误差。

RMSE的公式为:

误差:真实标签与预测标签之间的差异。

sqr_error:误差的平方。

mean_sqr_error:误差平方的均值。

sqrt_mean_sqr_error:误差平方均值的平方根(均方根误差)。

创建Huber损失函数:

图2:Huber损失函数(绿色)和平方误差损失函数(蓝色)(来源:Qwertyus— Own work,CCBY-SA4.0,https://commons.wikimedia.org/w/index.php?curid=34836380)

Huber损失函数的计算公式:

在此处,δ是阈值,a是误差(将计算出a,即实际标签和预测标签之间的差异)。

当|a|≤δ时,loss = 1/2*(a)2

当 |a|>δ时,loss = δ(|a|—(1/2)*δ)

源代码:

详细说明:

首先,定义一个函数—— my huber loss,它需要两个参数:y_true和y_pred,

设置阈值threshold = 1。

计算误差error a = y_true-y_pred。接下来,检查误差的绝对值是否小于或等于阈值,is_small_error返回一个布尔值(真或假)。

当|a|≤δ时,loss= 1/2*(a)2,计算small_error_loss, 误差的平方除以2。否则,当|a| >δ时,则损失等于δ(|a|-(1/2)*δ),用big_error_loss来计算这个值。

最后,在返回语句中,首先检查is_small_error是真还是假,如果它为真,函数返回small_error_loss,否则返回big_error_loss,使用tf.where来实现。

可以使用下述代码来编译模型:

在上述代码中,将阈值设为1。

如果需要调整超参数(阈值),并在编译过程中加入一个新的阈值的话,必须使用wrapper函数进行封装,也就是说,将损失函数封装成另一个外部函数。在这里需要用到封装函数(wrapper function),因为损失函数在默认情况下只能接受y_true和y_pred值,而且不能向原始损失函数添加任何其他参数。


使用封装后的Huber损失函数

封装函数的源代码:

此时,阈值不是硬编码,可以在模型编译过程中传递该阈值。

使用类实现Huber损失函数(OOP)

其中,MyHuberLoss是类名称,随后从tensorflow.keras.losses继承父类“Loss”, MyHuberLoss继承了Loss类,之后可以将MyHuberLoss当作损失函数来使用。

__init__ 初始化该类中的对象。执行类实例化对象时调用函数,init函数返回阈值,调用函数得到y_true和y_pred参数,将阈值声明为一个类变量,可以给它赋一个初始值。

在__init__函数中,将阈值设置为self.threshold。在调用函数中,self.threshold引用所有的阈值类变量。在model.compile中使用这个损失函数:

创建对比性损失(用于Siamese网络):

Siamese网络可以用来比较两幅图像是否相似,Siamese网络使用的损失函数为对比性损失。

在上文的公式中,Y_true是关于图像相似性细节的张量,如果图像相似,则为1,如果图像不相似,则为0。

D是图像对之间的欧氏距离的张量。边际为一个常量,用它来设置将图像区别为相似或不同的最小距离。如果为Y_true=1,则方程的第一部分为D2,第二部分为0,所以,当Y_true接近1时,D2的权重则更重。

如果Y_true=0,则方程的第一部分变为0,第二部分会产生一些结果,这给了最大项更多的权重,给了D平方项更少的权重,此时,最大项在损失计算中占了优势。

使用封装器函数实现对比损失函数:

结论

在Tensorflow中没有的损失函数都可以利用函数、包装函数或类似的类来创建。


原文标题:

Creating custom Loss functionsusing TensorFlow 2

原文链接:

https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c

编辑:黄继彦

校对:林亦霖

相关推荐

Python函数参数和返回值类型:让你的代码更清晰、更健壮

在Python开发中,你是否遇到过这些抓狂时刻?同事写的函数参数类型全靠猜调试两小时发现传了字符串给数值计算函数重构代码时不知道函数返回的是列表还是字典今天教你两招,彻底解决类型混乱问题!让你的...

有公司内部竟然禁用了python开发,软件开发何去何从?

今天有网友在某社交平台发文:有公司内部竟然禁止了python开发!帖子没几行,评论却炸锅了。有的说“太正常,Python本就不适合做大项目”,还有的反驳“飞书全员用Python”。暂且不说这家公司...

写 Python 七年才发现的七件事:真正提高生产力的脚本思路

如果你已经用Python写了不少脚本,却总觉得代码只是“能跑”,这篇文章或许会刷新你对这门语言的认知。以下七个思路全部来自一线实战,没有花哨的概念,只有可落地的工具与习惯。它们曾帮我省下大量无意义...

用Python写一个A*搜索算法含注释说明

大家好!我是幻化意识流。今天我们用Python写一个A*搜索算法的代码,我做了注释说明,欢迎大家一起学习:importheapq#定义搜索节点类,包括当前状态、从初始状态到该状态的代价g、从该状态...

使用python制作一个贪吃蛇游戏,并为每一句添加注释方便学习

今天来设计一个贪吃蛇的经典小游戏。先介绍下核心代码功能(源代码请往最后面拉):游戏功能:-四个难度等级:简单(8FPS)、中等(12FPS)、困难(18FPS)、专家(25FPS)-美...

Python 之父 Guido van Rossum 宣布退休

Python之父GuidovanRossum在推特公布了自己从Dropbox公司离职的消息,并表示已经退休。他还提到自己在Dropbox担任工程师期间学到了很多东西——Python的类型注解(T...

4 个早该掌握的 Python 类型注解技巧

在Python的开发过程中,类型注解常常被忽视。但当面对一段缺乏类型提示、逻辑复杂的代码时,理解和维护成本会迅速上升,极易陷入“阅读地狱”。本文整理了4个关于Python类型注解的重要技巧...

让你的Python代码更易读:7个提升函数可读性的实用技巧

如果你正在阅读这篇文章,很可能你已经用Python编程有一段时间了。今天,让我们聊聊可以提升你编程水平的一件事:编写易读的函数。请想一想:我们花在阅读代码上的时间大约是写代码的10倍。所以,每当你创建...

Python异常模块和包

异常当检测到一个错误时,Python解释器就无法继续执行了,反而出现了一些错误的提示,这就是所谓的“异常”,也就是我们常说的BUG例如:以`r`方式打开一个不存在的文件。f=open('...

别再被 return 坑了!一文吃透 Python return 语句常见错误与调试方法

Pythonreturn语句常见错误与调试方法(结构化详解)一.语法错误:遗漏return或返回值类型错误错误场景pythondefadd(a,b):print(a+b)...

Python数据校验不再难:Pydantic库的工程化实践指南

在FastAPI框架横扫Python后端开发领域的今天,其默认集成的Pydantic库正成为处理数据验证的黄金标准。这个看似简单的库究竟隐藏着哪些让开发者爱不释手的能力?本文将通过真实项目案例,带您解...

python防诈骗的脚本带注释信息

以下是一个简单但功能完整的防诈骗脚本,包含URL检测、文本分析和风险评估功能。代码结构清晰,带有详细注释,适合作为个人或家庭防诈骗工具使用。这个脚本具有以下功能:文本诈骗风险分析:检测常见诈骗关键...

Python判断语句

布尔类型和比较运算符布尔类型的定义:布尔类型只有两个值:True和False可以通过定义变量存储布尔类型数据:变量名称=布尔类型值(True/False)布尔类型不仅可以自行定义,同时也可通过...

使用python编写俄罗斯方块小游戏并为每一句添加注释,方便学习

先看下学习指导#俄罗斯方块游戏开发-Python学习指导##项目概述这个俄罗斯方块游戏是一个完整的Python项目,涵盖了以下重要的编程概念:-面向对象编程(OOP)-游戏开发基础-数据...

Python十大技巧:不掌握这些,你可能一直在做无用功!

在编程的世界里,掌握一门语言只是起点,如何写出优雅、高效的代码才是真功夫。Python作为最受欢迎的编程语言之一,拥有简洁明了的语法,但要想真正精通这门语言,还需要掌握一些实用的高级技巧。一、列表推导...

取消回复欢迎 发表评论: