【使用 TensorFlow 2】03/3 创建自定义损失函数

一、说明

        TensorFlow 2发布已经接近5年时间,不仅继承了Keras快速上手和易于使用的特性,同时还扩展了原有Keras所不支持的分布式训练的特性。3大设计原则:简化概念,海纳百川,构建生态.这是本系列的第三部分,我们将创建代价函数并在 TensorFlow 2 中使用它们。
 

图 1:实际应用中的梯度下降算法

二、关于代价函数

        神经网络学习将训练数据中的一组输入映射到一组输出。它通过使用某种形式的优化算法来实现这一点,例如梯度下降、随机梯度下降、AdaGrad、AdaDelta 或一些最近的算法,例如 Adam、Nadam 或 RMSProp。梯度下降中的“梯度”指的是误差梯度。每次迭代后,网络都会将其预测输出与实际输出进行比较,然后计算“误差”。通常,对于神经网络,我们寻求最小化错误。因此,用于最小化误差的目标函数通常称为成本函数或损失函数,并且由“损失函数”计算的值简称为“损失”。各种问题中使用的典型损失函数 –

A。均方误差

b. 均方对数误差

C。二元交叉熵

d. 分类交叉熵

e. 稀疏分类交叉熵

        在Tensorflow中,这些损失函数已经包含在内,我们可以如下所示调用它们。

        1 损失函数作为字符串

model.compile(损失='binary_crossentropy',优化器='adam',指标=['准确性'])

        或者,

        2. 损失函数作为对象

从tensorflow.keras.losses导入mean_squared_error

model.compile(损失=mean_squared_error,优化器='sgd')

        将损失函数作为对象调用的优点是我们可以在损失函数旁边传递参数,例如阈值。

从tensorflow.keras.losses导入mean_squared_error

model.compile(损失=均方误差(参数=值),优化器='sgd')

三、使用函数创建自定义损失:

        为了使用函数创建损失,我们需要首先命名损失函数,它将接受两个参数,y_true(真实标签/输出)和y_pred(预测标签/输出)。

def loss_function(y_true, y_pred):

***一些计算***

回波损耗

四、创建均方根误差损失 (RMSE):

        损失函数名称 — my_rmse

        目的是返回目标 (y_true) 和预测 (y_pred) 之间的均方根误差。

        RMSE 公式:

  • 误差:真实标签和预测标签之间的差异。
  • sqr_error:误差的平方。
  • mean_sqr_error:误差平方的平均值
  • sqrt_mean_sqr_error:误差平方均值的平方根(均方根误差)。
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K
#defining the loss function
def my_rmse(y_true, y_pred):#difference between true label and predicted labelerror = y_true-y_pred    #square of the errorsqr_error = K.square(error)#mean of the square of the errormean_sqr_error = K.mean(sqr_error)#square root of the mean of the square of the errorsqrt_mean_sqr_error = K.sqrt(mean_sqr_error)#return the errorreturn sqrt_mean_sqr_error
#applying the loss function
model.compile (optimizer = 'sgd', loss = my_rmse)

五、创建 Huber 损失 

图 2:Huber 损失(绿色)和平方误差损失(蓝色)作为 y — f(x) 的函数

         Huber损失的公式:

        这里,

        δ是阈值,

        a是误差(我们将计算 a ,标签和预测之间的差异)

        所以,当|a| ≤δ,损失= 1/2*(a)²

        当|a|>δ 时,损失 = δ(|a| — (1/2)*δ)

        代码:

# creating the Conv-Batch Norm blockdef conv_bn(x, filters, kernel_size, strides=1):x = Conv2D(filters=filters, kernel_size = kernel_size, strides=strides, padding = 'same', use_bias = False)(x)x = BatchNormalization()(x)
return x

        解释:

        首先我们定义一个函数 - my huber loss,它接受 y_true 和 y_pred

        接下来我们设置阈值 = 1

        接下来我们计算误差 a = y_true-y_pred

        接下来我们检查误差的绝对值是否小于或等于阈值。is_small_error返回一个布尔值(True 或 False)。

        我们知道,当|a| ≤δ,loss = 1/2*(a)²,因此我们将small_error_loss计算为误差的平方除以2 

        否则,当|a| >δ,则损失等于 δ(|a| — (1/2)*δ)。我们在big_error_loss中计算这一点。

        最后,在return语句中,我们首先检查is_small_error是true还是false,如果是true,函数返回small_error_loss,否则返回big_error_loss。这是使用 tf.where 完成的。

        然后我们可以使用下面的代码编译模型,

model.compile(optimizer='sgd', loss=my_huber_loss)

在前面的代码中,我们始终使用阈值1。

但是,如果我们想要调整超参数(阈值)并在编译期间添加新的阈值,该怎么办?然后我们必须使用函数包装,即将损失函数包装在另一个外部函数周围。我们需要一个包装函数,因为默认情况下任何损失函数只能接受 y_true 和 y_pred 值,并且我们不能向原始损失函数添加任何其他参数。

5.1 使用包装函数的 Huber 损失

        包装函数代码如下所示:

import tensorflow as tf
#wrapper function which accepts the threshold parameter
def my_huber_loss_with_threshold(threshold):def my_huber_loss(y_true, y_pred):   error = y_true - y_pred     is_small_error = tf.abs(error) <= threshold     small_error_loss = tf.square(error) / 2     big_error_loss = threshold * (tf.abs(error) - (0.5 * threshold))return tf.where(is_small_error, small_error_loss, big_error_loss)return my_huber_loss

在这种情况下,阈值不是硬编码的。相反,我们可以在模型编译期间通过阈值。

model.compile(optimizer='sgd', loss=my_huber_loss_with_threshold(threshold=1.5))

5.2 使用类的 Huber 损失 (OOP)

import tensorflow as tf
from tensorflow.keras.losses import Lossclass MyHuberLoss(Loss): #inherit parent class#class attributethreshold = 1#initialize instance attributesdef __init__(self, threshold):super().__init__()self.threshold = threshold#compute lossdef call(self, y_true, y_pred):error = y_true - y_predis_small_error = tf.abs(error) <= self.thresholdsmall_error_loss = tf.square(error) / 2big_error_loss = self.threshold * (tf.abs(error) - (0.5 * self.threshold))return tf.where(is_small_error, small_error_loss, big_error_loss)

        MyHuberLoss是类名。在类名之后,我们从tensorflow.keras.losses继承父类'Loss'。所以MyHuberLoss继承为Loss。这允许我们使用 MyHuberLoss 作为损失函数。

        __init__从类中初始化对象。

        从类实例化对象时执行的调用函数

        init 函数获取阈值,call 函数获取我们之前出售的 y_true 和 y_pred 参数。因此,我们将阈值声明为类变量,这允许我们给它一个初始值。

        在 __init__ 函数中,我们将阈值设置为 self.threshold。

        在调用函数中,所有阈值类变量将由 self.threshold 引用。

        以下是我们如何在 model.compile 中使用这个损失函数。

model.compile(optimizer='sgd', loss=MyHuberLoss(threshold=1.9))

六、创建对比损失(用于暹罗网络):

        连体网络比较两个图像是否相似。对比损失是暹罗网络中使用的损失函数。

        在上面的公式中,

        Y_true 是图像相似度细节的张量。如果图像相似,它们就是 1,如果不相似,它们就是 0。

        D 是图像对之间的欧几里德距离的张量。

        边距是一个常量,我们可以用它来强制它们之间的最小距离,以便将它们视为相似或不同。

        如果Y_true =1,则方程的第一部分变为 D²,第二部分变为零。因此,当 Y_true 接近 1 时,D² 项具有更大的权重。

        如果Y_true = 0,则方程的第一部分变为零,第二部分产生一些结果。这为最大项赋予了更大的权重,而为 D 平方项赋予了更少的权重,因此最大项在损失的计算中占主导地位。

使用包装函数的对比损失

def contrastive_loss_with_margin(margin):def contrastive_loss(y_true, y_pred):square_pred = K.square(y_pred)margin_square = K.square(K.maximum(margin - y_pred, 0))return K.mean(y_true * square_pred + (1 - y_true) * margin_square)return contrastive_loss

七、结论 

        Tensorflow 中不可用的任何损失函数都可以使用函数、包装函数或以类似的方式使用类来创建。阿琼·萨卡

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/157512.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

区块链加密虚拟货币交易平台安全解决方案

区块链机密货币交易锁遭入侵&#xff0c;安全存在隐患。使用泰雷兹Protect server HSM加密机&#xff0c;多方位保护您的数据&#xff0c;并通过集中化管理&#xff0c;安全的存储密钥。 引文部分&#xff1a; 损失7000万美元!黑客入侵香港区块链加密货币交易所 2023年9月&…

如何在Ubuntu 20.04.6 LTS系统上运行Playwright自动化测试

写在前面 这里以 Ubuntu 20.04.6 LTS为例。示例代码&#xff1a;自动化测试代码。 如果过程中遇到其他非文本中提到的错误&#xff0c;可以使用搜索引擎搜索错误&#xff0c;找出解决方案&#xff0c;再逐步往下进行。 一、 环境准备 1.1 安装python3 1.1.1 使用APT安装Py…

【Hello Algorithm】暴力递归到动态规划(二)

暴力递归到动态规划&#xff08;二&#xff09; 背包问题递归版本动态规划 数字字符串改字母字符串递归版本动态规划 字符串贴纸递归版本动态规划 **特别需要注意的是 我们使用数组之前一定要进行初始化 不然很有可能会遇到一些意想不到的错误 比如说在Linux平台上 new出来的in…

记一次生产大对象及GC时长优化经验

最近在做一次系统整体优化,发现系统存在GC时长过长及JVM内存溢出的问题,记录一下优化的过程 面试的时候我们都被问过如何处理生产问题&#xff0c;尤其是线上oom或者GC调优的问题更是必问&#xff0c;所以到底应该如何发现解决这些问题呢&#xff0c;用真实的场景实操&#xff…

2015架构案例(五十一)

第5题 【说明】某信息技术公司计划开发一套在线投票系统&#xff0c;用于为市场调研、信息调查和销售反馈等业务提供服务。该系统计划通过大量宣传和奖品鼓励的方式快速积累用户&#xff0c;当用户规模扩大到一定程度时&#xff0c;开始联系相关企业提供信息服务&#xff0c;并…

批量执行insert into 的脚本报2006 - MySQL server has gone away

数据库执行批量数据导入是报“2006 - MySQL server has gone away”错误&#xff0c;脚本并没有问题&#xff0c;只是insert into 的批量操作语句过长导致。 解决办法&#xff1a; Navicat ->工具 ->服务器监控->mysql ——》变量 修改max_allowed_packet大小为512…

TCP/IP(七)TCP的连接管理(四)全连接

一 全连接队列 nginx listen 参数backlog的意义 nginx配置文件中listen后面的backlog配置 ① TCP全连接队列概念 全连接队列: 也称 accept 队列 ② 查看应用程序的 TCP 全连接队列大小 实验1&#xff1a; ss 命令查看 LISTEN状态下 Recv-Q/Send-Q 含义附加&#xff1a;…

【Java学习之道】日期与时间处理类

引言 在前面的章节中&#xff0c;我们介绍了Java语言的基础知识和核心技能&#xff0c;现在我们将进一步探讨Java中的常用类库和工具。这些工具和类库将帮助我们更高效地进行Java程序开发。在本节中&#xff0c;我们将一起学习日期与时间处理类的使用。 一、为什么需要日期和…

vsCode 忽略 文件上传

1 无 .gitignore 文件时&#xff0c;在项目文件右键&#xff0c;Git Bash 进入命令行 输入 touch .gitignore 生成gitignore文件 2 、在文件.gitignore里输入 node_modules/ dist/ 来自于&#xff1a;vscode git提交代码忽略node_modules_老妖zZ的博客-CSDN博客

k8s - Flannel

1.Flannel概念剖析 Flannel是 CoreOS 团队针对 Kubernetes 设计的一个覆盖网络&#xff08;Overlay Network&#xff09;工具&#xff0c;其目的在于帮助每一个使用 Kuberentes 的 CoreOS 主机拥有一个完整的子网。这次的分享内容将从Flannel的介绍、工作原理及安装和配置三方…

④. GPT错误:导入import pandas as pd库,存储输入路径图片信息存储错误

꧂ 问题最初꧁ 用 import pandas as pd 可是你没有打印各种信息input输入图片路径 print图片尺寸 大小 长宽高 有颜色占比>0.001的按照大小排序将打印信息存储excel表格文件名 表格路径 图片大小 尺寸 颜色类型 占比信息input输入的是文件就处理文件 是文件夹&#x1f4c…

44.ES

一、ES。 &#xff08;1&#xff09;es概念。 &#xff08;1.1&#xff09;什么是es。 &#xff08;1.2&#xff09;es的发展。 es是基于lucene写的。 &#xff08;1.3&#xff09;总结。 es是基于lucene写的。 &#xff08;2&#xff09;倒排索引。 &#xff08;3&#xf…

flutter 开发中的问题与技巧

一、概述 刚开始上手 flutter 开发的时候&#xff0c;总会遇到这样那样的小问题&#xff0c;而官方文档又没有明确说明不能这样使用&#xff0c;本文总结了一些开发中经常会遇到的一些问题和一些开发小技巧。 二、常见问题 1、Expanded 组件只能在 Row、Column、Flex 中使用 C…

GEE:基于GLDAS数据集分析土壤湿度的时间序列变化

作者:CSDN @ _养乐多_ 本篇博客将介绍如何使用Google Earth Engine(GEE)进行土壤湿度数据的分析。我们将使用NASA GLDAS(Global Land Data Assimilation System)数据集,其中包括了关于土壤湿度的信息。通过该数据集,我们将了解土壤湿度在特定区域和时间段内的变化,并生…

springboot vue 部署至Rocky(Centos)并自启,本文部署是若依应用

概述 1、安装nohup&#xff08;后台进程运行java&#xff09; 2、安装中文字体&#xff08;防止中文乱码&#xff09; 3、安装chrony&#xff08;保证分布式部署时间的一致性&#xff09; 5、安装mysql数据&#xff0c;迁移目录&#xff0c;并授权自启动&#xff1b; 6、安…

SpringBoot注解篇之@Validated

目录 前言Validated作用NotNull与NotBlank区别总结 前言 大家好&#xff0c;我是AK&#xff0c;在做新项目顺便整理SpringBoot相关内容&#xff0c;这里主要介绍下Validated注解的应用&#xff0c;减少核心业务逻辑中一些参数判断的代码。 Validated作用 Validated 是 Spring…

Linux友人帐之系统管理与虚拟机相关

一、虚拟机相关操作 1.1虚拟机克隆 虚拟机克隆是指将一个已经安装好的虚拟机复制出一个或多个完全相同的副本&#xff0c;包括虚拟机的配置、操作系统、应用程序等&#xff0c;从而节省安装和配置的时间和资源。 虚拟机克隆的主要用途有&#xff1a; 创建多个相同或相似的虚拟…

论文导读|八月下旬Operations Research文章精选:定价问题专题

编者按&#xff1a; ​ ​在“ Operations Research论文精选”中&#xff0c;我们有主题、有针对性地选择了Operations Research中一些有趣的文章&#xff0c;不仅对文章的内容进行了概括与点评&#xff0c;而且也对文章的结构进行了梳理&#xff0c;旨在激发广大读者的阅读兴…

win10搭建gtest测试环境+vs2019

首先是下载gtest&#xff0c;这个我已经放在了博客上方资源绑定处&#xff0c;这个适用于win10vs版本&#xff0c;关于liunx版本的不能用这个。 或者百度网盘链接&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/15m62KAJ29vNe1mrmAcmehA 提取码&#xff1a;vfxz 下…

asp.net会议预约管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 会议预约管理系统 是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语 言开发 asp.net 会议预约管理系统 二、…