【使用 TensorFlow 2】02/3 使用 Lambda 层创建自定义激活函数

一、说明

        TensorFlow 2发布已经接近2年时间,不仅继承了Keras快速上手和易于使用的特性,同时还扩展了原有Keras所不支持的分布式训练的特性。3大设计原则:简化概念,海纳百川,构建生态.这是本系列的第三部分,我们将创建激活层并在 TensorFlow 2 中训练它们。
        之前我们已经了解了如何创建自定义损失函数 -使用 TensorFlow 2 创建自定义损失函数

自定义ReLU函数(来源:作者创建的图片)

二、介绍

        在本文中,我们将了解如何创建自定义激活函数。虽然 TensorFlow 已经包含一堆内置激活函数,但有多种方法可以创建您自己的自定义激活函数或编辑现有激活函数。

        ReLU(修正线性单元)仍然是任何神经网络架构的隐藏层中最常用的激活函数。ReLU 也可以表示为函数 f(x),其中,

        f(x) = 0, 当 x < 0 时,

        并且,当 x ≥ 0 时,f(x) = x。

        因此,该函数仅考虑正部分,并写为:

        f(x) = 最大值(0,x)

        或在代码表示中,

if input > 0:return input
else:return 0

        但这个ReLU函数是预定义的。如果我们想自定义此函数或创建我们自己的 ReLU 激活该怎么办?在 TensorFlow 中有一种非常简单的方法可以做到这一点——我们只需使用Lambda 层

ReLU 和 GeLU 激活函数

如何使用 lambda 层?

tf.keras.layers.Lambda(lambda x: tf.abs(x))

Lambda 只是可以在 TensorFlow 中直接调用的另一层。在 lambda 层中,首先指定参数。在上面的代码片段中,该值为“x”(lambda x)。在本例中,我们想要求 x 的绝对值,因此我们使用 tf.abs(x)。因此,如果 x 的值为 -1,则该 lambda 层会将 x 的值更改为 1。

如何使用 lambda 层创建自定义 ReLU?

def custom_relu(x):return K.maximum(0.0,x)
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(128,128)),tf.keras.layers.Dense(512),tf.keras.layers.Lambda(custom_relu),tf.keras.layers.Dense(5, activation = 'softmax')
])

上面的代码片段展示了如何在 TensorFlow 模型中实现自定义 ReLU。我们创建一个函数 custom_relu 并返回 0 或 x 的最大值(与 ReLU 函数相同)。

在下面的顺序模型中,在 Dense 层之后,我们创建一个 Lambda 层并将其传递到自定义激活函数中。但这段代码仍然没有做任何与 ReLU 激活函数不同的事情。

当我们开始研究自定义函数的返回值时,乐趣就开始了。假设我们取 0.5 和 x 中的最大值,而不是 0 和 x。我们已经有了自己定制的 ReLU。然后可以根据需要更改这些值。

def custom_relu(x):
返回 K.maximum(0.5,x)

def custom_relu(x):return K.maximum(0.5,x)
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(128,128)),tf.keras.layers.Dense(512),tf.keras.layers.Lambda(custom_relu),tf.keras.layers.Dense(5, activation = 'softmax')
])

在 mnist 数据集上使用 lambda 激活的示例

#using absolute value (Lambda layer example 1)
import tensorflow as tf
from tensorflow.keras import backend as K
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128),tf.keras.layers.Lambda(lambda x: tf.abs(x)), tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

将 ReLU 激活替换为 mnist 数据集上的绝对值,测试精度为 97.384%。

#using custom ReLU activation (Lambda layer example 2)
import tensorflow as tf
from tensorflow.keras import backend as K
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def my_relu(x):return K.maximum(-0.1, x)
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128),tf.keras.layers.Lambda(my_relu), tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

        将 ReLU 激活替换为自定义 ReLU 激活,在 mnist 数据集上取最大值 -0.1 或 x,测试精度为 97.778%。

三、结论

        尽管 lambda 层使用起来非常简单,但它们有很多限制。在下一篇文章中,我将介绍如何在 TensorFlow 中创建可训练的完全自定义层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/157580.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JMeter性能测试,完整入门篇

1. Jmeter简介 Apache JMeter是一款纯java编写负载功能测试和性能测试开源工具软件。相比Loadrunner而言&#xff0c;JMeter小巧轻便且免费&#xff0c;逐渐成为了主流的性能测试工具&#xff0c;是每个测试人员都必须要掌握的工具之一。 本文为JMeter性能测试完整入门篇&…

Web自动化测试入门 : 前端页面的组成分析详解

目前常见的前端页面是由HTMLcssJavaScript组成。 一、HTML&#xff1a; 作用&#xff1a;定义页面呈现的内容 HTML 是用来描述网页的一种语言。 HTML 指的是超文本标记语言 (Hyper Text Markup Language)HTML 不是一种编程语言&#xff0c;而是一种标记语言 (markup langua…

并发编程——1.java内存图及相关内容

这篇文章&#xff0c;我们来讲一下java的内存图及并发编程的预备内容。 首先&#xff0c;我们来看一下下面的这两段代码&#xff1a; 下面&#xff0c;我们给出上面这两段代码在运行时的内存结构图&#xff0c;如下图所示&#xff1a; 下面&#xff0c;我们来具体的讲解一下。…

day14I102.二叉树的层序遍历

1、102.二叉树的层序遍历 题目链接&#xff1a;https://leetcode.cn/problems/binary-tree-level-order-traversal/ 文章链接&#xff1a;https://programmercarl.com/0102.%E4%BA%8C%E5%8F%89%E6%A0%91%E7%9A%84%E5%B1%82%E5%BA%8F%E9%81%8D%E5%8E%86.html#%E7%AE%97%E6%B3%95…

C# AnimeGANv2 人像动漫化

效果 项目 下载 可执行程序exe下载 源码下载 其他 C# 人像卡通化 Onnx photo2cartoon-CSDN博客 C# AnimeGAN 漫画风格迁移 动漫风格迁移 图像卡通化 图像动漫化_天天代码码天天的博客-CSDN博客

最新AI创作系统ChatGPT源码+详细搭建部署教程,支持AI绘画/支持OpenAI-GPT全模型+国内AI全模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统AI绘画系统&#xff0c;支持OpenAI GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署…

SpringSecurity + jwt + vue2 实现权限管理 , 前端Cookie.set() 设置jwt token无效问题(已解决)

问题描述 今天也是日常写程序的一天 , 还是那个熟悉的IDEA , 还是那个熟悉的Chrome浏览器 , 还是那个熟悉的网站 , 当我准备登录系统进行登录的时候 , 发现会直接重定向到登录页 , 后端也没有报错 , 前端也没有报错 , 于是我得脸上又多了一张痛苦面具 , 紧接着在前端疯狂debug…

uniapp 使用和引入 thorui

1. npm install thorui-uni 2. "easycom": { "autoscan": true, "custom": { "tui-(.*)": "thorui-uni/lib/thorui/tui-$1/tui-$1.vue" } }, 3.

XML外部实体注入攻击XXE

xml是扩展性标记语言&#xff0c;来标记数据、定义数据类型&#xff0c;是一种允许用户对自己的标记语言进行定义的源语言。XML文档结构包括XML声明、DTD文档类型定义&#xff08;可选&#xff09;、文档元素&#xff0c;一般无法直接打开&#xff0c;可以选择用excl或记事本打…

CentOS 7下JumpServer安装及配置(超详细版)

前言 Jumpserver是一种用于访问和管理远程设备的Web应用程序&#xff0c;通常用于对服务器进行安全访问。它基于SSH协议&#xff0c;提供了一个安全和可管理的环境来管理SSH访问。Jumpserver是基于Python开发的一款开源工具&#xff0c;其提供了强大的访问控制功能&#xff0c;…

NLP项目:维基百科文章爬虫和分类【02】 - 语料库转换管道

一、说明 我的NLP项目在维基百科条目上下载、处理和应用机器学习算法。相关上一篇文章中&#xff0c;展示了项目大纲&#xff0c;并建立了它的基础。首先&#xff0c;一个 Wikipedia 爬网程序对象&#xff0c;它按名称搜索文章&#xff0c;提取标题、类别、内容和相关页面&…

CSS 中几种常用的换行方法

1、使用 br 元素&#xff1a; 最简单的换行方法是在需要换行的位置插入 元素。例如&#xff1a; <p>This is a sentence.<br>It will be on a new line.</p>这会在 “This is a sentence.” 和 “It will be on a new line.” 之间创建一个换行。 效果&a…

云表:MES系统是工业4.0数字化转型的核心

随着信息技术与工业技术的深度融合&#xff0c;网络、计算机技术、信息技术、软件与自动化技术相互交织&#xff0c;产生了全新的价值模式。在制造领域&#xff0c;这种资源、信息、物品和人相互关联的模式被德国人定义为“工业4.0”&#xff0c;也就是第四次工业革命。工业4.0…

C#中的Dispatcher:Invoke与BeginInvoke的使用

Dispatcher是.NET框架中的一个重要概念&#xff0c;用于处理异步消息传递。在C#中&#xff0c;Dispatcher提供了两种方法&#xff1a;Invoke和BeginInvoke&#xff0c;用于控制线程上消息的顺序和执行方式。 目录 一、Dispatcher.Invoke二、Dispatcher.BeginInvoke三、使用场景…

flink1.15 savepoint 超时报错 java.util.concurrent.TimeoutException

savepoint命令 flink savepoint e04813d4e7480c526912eb4d32bba510 hdfs://flink/flink/migration/savepoint56650 -Dyarn.application.id=application_1683808492336_1222报错内容 org.apache.flink.util.FlinkException: Triggering a savepoint for the job e04813d4e7480…

Excel 自动提取某一列不重复值

IFERROR(INDEX($A$1:$A$14,MATCH(0,COUNTIF($C$1:C1,$A$1:$A$14),0)),"")注意&#xff1a;C1要空置&#xff0c;从C2输入公式 参考&#xff1a; https://blog.csdn.net/STR_Liang/article/details/105182654 https://zhuanlan.zhihu.com/p/55219017?utm_id0

【Unity】【VR】详解Oculus Integration输入

【背景】 以下内容适用于Oculus Integration开发VR场景,也就是OVR打头的Scripts,不适用于OpenXR开发场景,也就是XR打头Scripts。 【详解】 OVR的Input相对比较容易获取。重点在于区分不同动作机制的细节效果。 OVR Input的按键存在Button和RawButton两个系列 RawButton…

vue-mixin

1.vue中&#xff0c;混入(mixin)是一种特殊的使用方式。一个混入对象可以包含任意的组件配置选项(data, props, components, watch,computed…)可以根据需求"封装"一些可复用的单元&#xff0c;并在使用时根据一定的策略合并到组件的选项中&#xff0c;使用时和组件自…

十六、【橡皮擦工具组】

文章目录 橡皮擦背景橡皮擦1. 一次取样2. 连续取样3. 取样背景色板 魔术橡皮擦 橡皮擦 橡皮擦跟我们平常生活中所用的橡皮擦是一样&#xff0c;它是将图层的内容擦除,只剩下空白部分。另外当我们按住Alt的键去擦除空白部分的时候&#xff0c;也可以将背景的部分显示出来。 另…

Maven系列第4篇:仓库详解

maven系列目标&#xff1a;从入门开始开始掌握一个高级开发所需要的maven技能。 这是maven系列第4篇。 整个maven系列的内容前后是有依赖的&#xff0c;如果之前没有接触过maven&#xff0c;建议从第一篇看起&#xff0c;本文尾部有maven完整系列的连接。 环境 maven3.6.1 …