政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

目录

介绍

编写组件

训练模型


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):def __init__(self, units, activation=None, name=None):super().__init__(name=name)self.units = unitsself.activation = keras.activations.get(activation)def build(self, input_shape):input_dim = input_shape[-1]self.w = self.add_weight(shape=(input_dim, self.units),initializer=keras.initializers.GlorotNormal(),name="kernel",trainable=True,)self.b = self.add_weight(shape=(self.units,),initializer=keras.initializers.Zeros(),name="bias",trainable=True,)def call(self, inputs):# Use Keras ops to create backend-agnostic layers/metrics/etc.x = keras.ops.matmul(inputs, self.w) + self.breturn self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):def __init__(self, rate, name=None):super().__init__(name=name)self.rate = rate# Use seed_generator for managing RNG state.# It is a state element and its seed variable is# tracked as part of `layer.variables`.self.seed_generator = keras.random.SeedGenerator(1337)def call(self, inputs):# Use `keras.random` for random ops.return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):def __init__(self, num_classes):super().__init__()self.conv_base = keras.Sequential([keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.GlobalAveragePooling2D(),])self.dp = MyDropout(0.5)self.dense = MyDense(num_classes, activation="softmax")def call(self, x):x = self.conv_base(x)x = self.dp(x)return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)model.fit(x_train,y_train,batch_size=batch_size,epochs=1,  # For speedvalidation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)
)# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tftrain_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/293030.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手写简易操作系统(十七)--编写键盘驱动

前情提要 上一节我们实现了锁与信号量,这一节我们就可以实现键盘驱动了,访问键盘输入的数据也属于临界区资源,所以需要锁的存在。 一、键盘简介 之前的 ps/2 键盘使用的是中断驱动的,在当时,按下键盘就会触发中断&a…

Abaqus周期性边界代表体单元Random Sphere RVE 3D (Mesh)插件

插件介绍 Random Sphere RVE 3D (Mesh) - AbyssFish 插件可在Abaqus生成三维具备周期性边界条件(Periodic Boundary Conditions, PBC)的随机球体骨料及骨料-水泥界面过渡区(Interfacial Transition Zone, ITZ)模型。即采用周期性代表性体积单元法(Periodic Representative Vol…

1.8 python 模块 time、random、string、hashlib、os、re、json

ython之模块 一、模块的介绍 (1)python模块,是一个python文件,以一个.py文件,包含了python对象定义和pyhton语句 (2)python对象定义和python语句 (3)模块让你能够有逻辑地…

Cookie 与 Session

目录 一、获取Cookie/Session 1、理解Cookie 2、理解Session 3、Cookie 和 Session 的区别 4、获取Cookie 4.1 传统获取Cookie 4.2 简洁获取Cookie 5、Session 存储和获取 5.1 Session存储 5.2 Session读取 5.3 简洁获取 Session 一、获取Cookie/Session 1、理解Co…

【Linux】详解软硬链接

一、软硬链接的建立方法 1.1软链接的建立 假设在当前目录下有一个test.txt文件,要对其建立软链接,做法如下: ln就是link的意思,-s表示软链接,test.txt要建立软链接的文件名,后面跟上要建立的软链接文件名…

k8s1.28.8版本配置Alertmanager报警方式(邮件,企业微信)

文章目录 总结部署流程 Alertmanager 三大核心1. 分组告警2. 告警抑制3. 告警静默 报警过滤静默通知方案一:方案二: 抑制报警规则案例一 参考文档 自定义路由告警,分来自不同路由的告警,艾特不同的人员进行区分修改 alertmanager …

linux C:变量、运算符

linux C 文章目录 变量运算符 一、变量 [存储类型] 数据类型 标识符 值 标识符:由数字、字母、下划线组成的序列,不能以数字开头。 数据类型:基本数据类型构造类型 存储类型:auto static…

4月深圳振威新能源产业博览会丨千万订单采购对接会!

4月深圳振威新能源产业博览会丨千万订单采购对接会! 目前,振威新能源海外招商团队已成功与俄罗斯、泰国多家组织机构建立合作联系!已确定携多家知名企业到现场采购! 电池与储能 同时,振威新能源团队也成功与泰国储能技…

【KingSCADA】播放语音

1.函数介绍 PlaySound(string strWaveFileName, int nMode);下面是官方帮助文档中的解释: 2.生成语音文件 3.使用脚本播放音频文件 将音频文件存放在工程目录下面,我存放在了…\Resources\文件夹下: 我简单的写了一个定时1分钟播放一次语…

Docket常见的软件部署1

1 安装MySQL # 查看MySQL镜像 docker search mysql # 拉起镜像 docker pull mysql:5.7 # 创建MySQL数据映射卷,防止数据不丢失 mkdir -p /hmoe/tem/docker/mysql/data/ # 启动镜像 docker run -d --name mysql -e MYSQL_ROOT_PASSWORD123456 -p 3306:3306 -v /home…

蓝桥杯第七届大学B组详解

目录 1.煤球数量; 2.生日蜡烛; 3.凑算式 4.方格填数 5.四平方和 6.交换瓶子 7.最大比例 1.煤球数量 题目解析:可以根据题目的意思,找到规律。 1 *- 1个 2 *** 3个 3 ****** 6个 4 ********** 10个 不难发现 第…

OSCP靶场--Internal

OSCP靶场–Internal 考点(CVE-2009-3103) 1.nmap扫描 ## ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.216.40 -sV -sC -Pn --min-rate 2500 -p- Starting Nmap 7.92 ( https://nmap.org ) at 2024-03-31 07:00 EDT Nmap scan report for 192.168.216.40 Host is up…

C++11新特性(二):更好用的 lambda 表达式和 function 包装器

目录 lambda 表达式 基本格式及参数列表 对于 lambda 捕捉列表的说明 function 包装器 bind 包装器 lambda 表达式 C11引入了lambda表达式,它是一种用于创建匿名函数的语法。lambda表达式可以被视为一个匿名函数对象,它可以在需要函数对象的地方使用…

PyTorch 教程-快速上手指南

文章目录 PyTorch Quickstart1.处理数据2.创建模型3.优化模型参数4.保存模型5.加载模型 PyTorch 基础入门1.Tensors1.1初始化张量1.2张量的属性1.3张量运算1.3.1张量的索引和切片1.3.2张量的连接1.3.3算术运算1.3.4单元素张量转变为Python数值 1.4Tensor与NumPy的桥接1.4.1Tens…

系统慢查询的思考

系统慢查询的思考 在一个系统中发现慢查询的功能或很卡的现象。你是怎么思考的?从哪几个方面去思考?会用什么工具? 一个系统使用了几年后都可能会出现这样的问题。原因可能有以下几点。 数据量的增加。系统中平时的使用中数据量是有一个累…

【AXIS】AXI-Stream FIFO设计实现(四)——异步时钟

前文介绍了几种同步时钟情况下的AXI Stream FIFO实现方式,一般来说,FIFO也需要承担异步时钟域模块间数据传输的功能,本文介绍异步AXIS FIFO的实现方式。 如前文所说,AXI-Stream FIFO十分类似于FWFT异步FIFO,推荐参考前…

MIPI CSI-2 Low Level Protocol解读

一、Low Level Protocol介绍 LLP 是一种面向字节的基于数据包的协议,支持使用短数据包和长数据包格式传输任意数据。为简单起见,本节中的所有示例均为单通道配置。 LLP特性: 传输任意数据(与有效载荷无关) 8 位字大…

Chatgpt掘金之旅—有爱AI商业实战篇(二)

演示站点: https://ai.uaai.cn 对话模块 官方论坛: www.jingyuai.com 京娱AI 一、前言: 成为一名商业作者是一个蕴含着无限可能的职业选择。在当下数字化的时代,作家们有着众多的平台可以展示和推广自己的作品。无论您是对写书、文…

MSPF5438数据卫星透传

最近在网上找了个项目来做,实现功能简单描述就是通过Lora模块E30-170T27D接收上位机发送的数据包,并对接收数据包进行正确性校验,若数据包校验成功则将其储存在W25Q125FV 中,待上位机发送数据包传输完毕指令后,单片机启…

Docker配置Mysql

1.首页搜索mysql镜像 2.选择对应版本的MySQL,点击pull 3.pull完成以后,点击images,这里可以看到刚刚pull完成的mysql版本 4.打开命令界面,运行命令 docker images ,查看当前已经pull的images 5.运行命令设置mysql docker run -it…