【TensorFlow2 之013】TensorFlow-Lite

一、说明

        在这篇文章中,我们将展示如何构建计算机视觉模型并准备将其部署在移动和嵌入式设备上。有了这些知识,您就可以真正将脚本部署到日常使用或移动应用程序中。

        教程概述:

  1. 介绍
  2. 在 TensorFlow 中构建模型
  3. 将模型转换为 TensorFlow Lite
  4. 训练后量化

二、前提知识

      上次,我们展示了如何使用迁移学习来提高模型性能。但是,当我们可以在智能手机或其他嵌入式设备上使用我们的模型时,为什么我们只使用我们的模型来预测计算机上猫或狗的图像呢?

        TensorFlow Lite 是 TensorFlow 针对移动和嵌入式设备的轻量级解决方案。它使我们能够在移动设备上以低延迟快速运行机器学习模型,而无需访问服务器。

        它可以通过 C++ API 在 Android 和 iOS 设备上使用,也可以与 Android 开发人员的 Java 包装器一起使用。

三、 在 TensorFlow 中构建模型

        在开始使用 TensorFlow Lite 之前,我们需要训练一个模型。我们将使用台式计算机或云平台等更强大的机器从一组数据训练模型。然后,可以导出该模型并在移动设备上使用。

        让我们首先准备训练数据集。我们将使用 wget .download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")
100% [........................................................................] 68606236 / 68606236

Out[2]:

'cats_and_dogs_filtered.zip'
with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        

        接下来,让我们导入所有必需的库并构建我们的模型。我们将使用一个名为MobileNetV2的预训练网络 ,该网络是在 ImageNet 数据集上进行训练的。之后,我们需要冻结预训练层并添加一个名为GlobalAveragePooling2D 的新层,然后是具有 sigmoid 激活函数的密集层。我们将使用图像数据生成器来迭代图像。

base_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.trainable = False
model = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)
Train for 63 steps, validate for 32 steps
Epoch 1/6
63/63 - 470s - loss: 0.3610 - accuracy: 0.8515 - val_loss: 0.1664 - val_accuracy: 0.9460
Epoch 2/6
63/63 - 139s - loss: 0.2055 - accuracy: 0.9210 - val_loss: 0.2065 - val_accuracy: 0.9150
Epoch 3/6
63/63 - 89s - loss: 0.1507 - accuracy: 0.9470 - val_loss: 0.1294 - val_accuracy: 0.9560
Epoch 4/6
63/63 - 81s - loss: 0.1342 - accuracy: 0.9510 - val_loss: 0.1733 - val_accuracy: 0.9350
Epoch 5/6
63/63 - 80s - loss: 0.1205 - accuracy: 0.9555 - val_loss: 0.1684 - val_accuracy: 0.9390
Epoch 6/6
63/63 - 82s - loss: 0.1079 - accuracy: 0.9620 - val_loss: 0.1418 - val_accuracy: 0.9450

四、 将模型转换为 TensorFlow Lite

        在 TensorFlow 中训练深度神经网络后,我们可以将其转换为 TensorFlow Lite。

        这里的想法是使用提供的 TensorFlow Lite 转换器来转换模型并生成 TensorFlow Lite FlatBuffer文件。此类文件的扩展名是.tflite

        可以直接从经过训练的 tf.keras 模型、保存的模型或具体的 TensorFlow 函数进行转换,这也是可能的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()open("converted_model.tflite", "wb").write(tflite_model)

        转换后,FlatBuffer 文件可以部署到客户端设备。在客户端设备上,它可以使用 TensorFlow Lite 解释器在本地运行。

https://www.tensorflow.org/lite/convert

        让我们看看如何使用TensorFlow Lite Interpreter直接从 Python 运行该模型。

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

        我们可以通过运行get_input_detailsget_output_details函数来获取模型的输入和输出详细信息。我们还可以调整输入和输出的大小,这样我们就可以对整批图像进行预测。

input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [  1 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [1 1]
type: <class 'numpy.float32'>
interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

        接下来,我们可以生成一批图像,在我们的例子中为 32 个,因为这是之前图像数据生成器中使用的数量。

        使用 TensorFlow Lite 进行预测的第一步是设置模型的输入,在我们的示例中是一批 32 张图像。设置输入张量后,我们可以调用invoke 命令来调用解释器。输出预测将通过调用get_tensor命令获得。

        我们来对比一下 TensorFlow Lite 模型和 TensorFlow 模型的结果。它们应该是相同的。我们将使用np.testing.assert_almost_equal函数来完成此操作。如果结果在一定的小数位数内不相同,则会引发错误。

mage_batch, label_batch = next(iter(validation_generator))
interpreter.set_tensor(input_details[0]['index'], image_batch)
interpreter.invoke()
tflite_results = interpreter.get_tensor(output_details[0]['index'])tf_results = model.predict(image_batch)for tf_result, tflite_result in zip(tf_results, tflite_results):np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)

        到目前为止没有错误,所以它向我们表明此转换已成功完成。但这是有阶级概率的。让我们编写一个函数将其转换为类。该函数将执行与 TensorFlow 中内置函数相同的操作。此外,我们还将展示这些预测。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

def predict_classes(tf_model, tf_lite_interpreter, x):tf_lite_interpreter.set_tensor(input_details[0]['index'], x)tf_lite_interpreter.invoke()tflite_results = tf_lite_interpreter.get_tensor(output_details[0]['index'])tf_results = model.predict_classes(x)if tflite_results.shape[-1] > 1:tflite_results = tflite_results.argmax(axis=-1)else:tflite_results = (tflite_results > 0.5).astype('int32')plt.figure(figsize=(12,10))for n in range(32):plt.subplot(6,6,n+1)plt.subplots_adjust(hspace = 0.3)plt.imshow(image_batch[n])color = "blue" if tf_results[n] == tflite_results[n] else "red"plt.title(tflite_results[n], color=color)plt.axis('off')_ = plt.suptitle("Model predictions (blue: same, red: different)")predict_classes(model, interpreter, image_batch)

模型预测 - 猫与狗

五、训练后量化

        这非常简单,但某些模型并未针对在低功耗设备上运行进行优化。因此,为了解决这个问题,我们需要进行量化。

        训练后量化是一种转换技术,可以减小模型大小,同时改善 CPU 和硬件加速器延迟,而模型精度几乎没有下降。

        下表显示了三种技术及其优点,以及可以运行该技术的硬件。

技术好处硬件
权重量化缩小 4 倍,加速 2-3 倍,准确度提高中央处理器
全整数量化缩小 4 倍,加速 3 倍以上CPU、边缘TPU等
Float16 量化体积缩小 2 倍,具有潜在的 GPU 加速能力中央处理器/图形处理器

        让我们看看它是如何进行权重量化的,这是默认的。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()open("converted_model.tflite", "wb").write(tflite_quant_model)interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()interpreter.resize_tensor_input(input_details[0]['index'], (32, 224, 224, 3))
interpreter.resize_tensor_input(output_details[0]['index'], (32, 5))
interpreter.allocate_tensors()input_details = interpreter.get_input_details()
output_details= interpreter.get_output_details()
print("== Input details ==")
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])
print("\n== Output details ==")
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])
== Input details ==
shape: [ 32 224 224   3]
type: <class 'numpy.float32'>== Output details ==
shape: [32  1]
type: <class 'numpy.float32'>

以下决策树可以帮助确定哪种训练后量化方法最适合您的用例。

决策树

        在我们的例子中,我们使用名为 MobileNetV2 的模型,该模型已经针对低功耗设备进行了优化。所以本例中的权重量化会降低模型精度。

        和之前一样,让我们​​看看这次模型的表现如何。红色将显示与主 TensorFlow 模型不同的预测,蓝色图像将显示两个模型做出相同的预测。作为标签,我们将显示 TensorFlow Lite 模型的预测。

predict_classes(model, interpreter, image_batch)

模型预测 2 - 猫与狗

六、总结

        总而言之,在这篇文章中,我们讨论了从 TensorFlow 到 TensorFlow Lite 的转换以及低功耗设备的优化。在下一篇文章中,我们将展示如何在 TensorFlow 2.0 中实现 LeNet-5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/159294.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mac 远程 Ubuntu

1. Iterm2 添加ssh 参考&#xff1a;https://www.javatang.com/archives/2021/11/29/13063392.html 2. Finder 添加远程文件管理 2.1 ubuntu 配置 安装samba sudo apt-get install samba配置 [share]path /home/USER_NAME/shared_directoryavailable yesbrowseable ye…

NewStarCTF2023week2-ez_sql

闭合之后尝试判断字段数&#xff0c;存在WAF&#xff0c;使用大小写绕过&#xff08;后面的sql语句也需要进行大小写绕过&#xff09; ?id1 Order by 5-- 测出有5列 ?id1 Order by 6-- 查一下数据库名、版本、用户等信息 ?id1Union Select database(),version(),user(),4,…

elementUI el-table+树形结构子节点选中后没有打勾?(element版本问题 已解决)

问题 1.不勾选父级CB111&#xff0c;直接去勾选子级&#xff08;ST2001…&#xff09;&#xff0c;子级选中后没有打勾显示 排查 一直以为是这个树形结构和表格不兼容产生的问题&#xff0c;到后来看官方demo都是可以勾选的&#xff0c;最后排查到了版本问题&#xff0c; 项…

数学建模——人工神经网络模型

一、人工神经网络简介 1、神经网络起源与应用 1943年心理学家McCulloch和数学家Pitts提出神经元生物数学模型&#xff08;M-P模型&#xff09;&#xff0c;后来人工神经网络(Artifical Neural Network,ANN)是在生物神经网络(Biological Neural Network,BNN)基础上发展起来的&a…

SystemC入门学习-第8章 测试平台的编写

之前的章节&#xff0c;一直把重点放在用SystemC来描述硬件电路上&#xff0c;即如何编写SystemC 的RTL。本章的注意力集中在验证和编写测试平台上。 重点包括&#xff1a; 如何生成时钟信号和激励波形如何编写有响应能力的测试平台如何记录仿真结果 8.1 编写测试平台 测试平…

IO流:java中解码和编码出现乱码说明及代码实现

IO流&#xff1a;java中解码和编码的代码实现 一、UTF-8和GBK编码方式二、idea和eclipse的默认编码方式三、解码和编码方法四、代码实现编码解码 五、额外知识扩展 一、UTF-8和GBK编码方式 如果采用的是UTF-8的编码方式&#xff0c;那么1个英文字母 占 1个字节&#xff0c;1个…

使用Swift开发Framework遇到的问题及解决方法

文章目录 一、Swift 旧版本Xcode 打出来的framework 新版本不兼容问题 一、Swift 旧版本Xcode 打出来的framework 新版本不兼容问题 Cannot load module xxx built with SDK ihphoneos16.4 when using SDK iphoneos17.0:XXX/xxx.framework/Modules/xxx.swiftmodule/arm64-appl…

java swing实现点击按钮切换图片(简单实现)

本文教程&#xff0c;主要提供一个简单的例子&#xff0c;使用java swing完成点击按钮能够切换图片。 目录 一、程序预览 二、程序代码 一、程序预览 二、程序代码 package learnProject.csdn;import java.awt.BorderLayout; import java.awt.FlowLayout; import java.awt.ev…

webpack 解决:Cannot use import statement outside a module 的问题

1、问题描述&#xff1a; 其一、报错为&#xff1a; Uncaught SyntaxError: Cannot use import statement outside a module; 中文为&#xff1a; 未捕获的语法错误&#xff1a;无法在模块外部使用 import 语句; 其二、问题描述为&#xff1a; 在项目打包的时候 npm run …

【Vue面试题二十一】、Vue中的过滤器了解吗?过滤器的应用场景有哪些?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;Vue中的过滤器了解吗&am…

Android组件通信——PendingIntent(二十八)

1. PendingIntent 1.1 知识点 &#xff08;1&#xff09;了解PendingIntent与Intent的区别&#xff1b; &#xff08;2&#xff09;可以完成Notification功能的开发&#xff1b; &#xff08;3&#xff09;可以使用PendingIntent进行短信的发送&#xff1b; 1.2 具体内容 …

asp.net老年大学信息VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio计算机毕业设计

一、源码特点 asp.net老年大学信息管理系统是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c# 语言开发 asp.net老年大学信息管理系统…

[论文笔记]SimCSE

引言 今天带来一篇当时引起轰动的论文SimCSE笔记,论文题目是 语句嵌入的简单对比学习。 SimCSE是一个简单的对比学习框架,它可以通过无监督和有监督的方式来训练。 对于无监督方式,输入一个句子然后在一个对比目标中预测它自己,仅需要标准的Dropout作为噪声。这种简单的…

“Python+”集成技术高光谱遥感数据处理与机器学习深度应用

涵盖高光谱遥感数据处理的基础、python开发基础、机器学习和应用实践。重点解释高光谱数据处理所涉及的基本概念和理论&#xff0c;旨在帮助学员深入理解科学原理。结合Python编程工具&#xff0c;专注于解决高光谱数据读取、数据预处理、高光谱数据机器学习等技术难题&#xf…

springboot中如何在测试环境下进行web环境模拟测试

web环境模拟测试 模拟端口 SpringBootTest(webEnvironment SpringBootTest.WebEnvironment.RANDOM_PORT) public class WebTest {Testvoid testRandomPort () {} }

docker 搭建本地Chat GPT

要在CentOS7上安装Docker&#xff0c;您可以按照以下步骤进行操作&#xff1a; 1、更新系统包列表 sudo yum update2、安装Docker存储库的必要软件包 sudo yum install -y yum-utils device-mapper-persistent-data lvm23、添加Docker存储库 sudo yum-config-manager --add…

MySQL MVCC详细介绍

MVCC概念 MVCC(Multi-Version Concurrency Control) 多版本并发控制&#xff0c;是一种并发控制机制,用于处理数据库中的并发读写操作&#xff0c;它通过在每个事务中创建数据的快照&#xff0c;实现了读写操作的隔离性&#xff0c;从而避免了读写冲突和数据不一致的问题。 M…

GaN器件的工作原理

目录 AlGaN/GaNHEMT 器件工作原理&#xff08;常开-耗尽型器件&#xff09;常关 AlGaN/GaN 功率晶体管&#xff08;增强型器件&#xff09;HD-GIT与SP-HEMT AlGaN/GaNHEMT 器件工作原理&#xff08;常开-耗尽型器件&#xff09; 来源&#xff1a;毫米波GaN基功率器件及MMIC电路…

百度SEO优化的特点(方式及排名诀窍详解)

百度SEO优化的特点介绍&#xff1a; 百度SEO优化是指对网站进行优化&#xff0c;使其在百度搜索引擎中获得更好的排名&#xff0c;进而获取更多的流量和用户。百度SEO优化的特点是综合性强、效果持久、成本低廉、投资回报高。百度的搜索算法不断更新&#xff0c;所以长期稳定的…

宝塔面板服务器内存使用率高的三招解决方法

卸载多余PHP版本。假若安装了多个PHP版本&#xff0c;甚至把 php 5.3、5.4、7.0、7.3 全都安装上了&#xff0c;就会严重增加系统负载和内存使用率。 安装memcached 缓存组件&#xff0c;建议在宝塔面板后台直接安装。 卸载不常用软件。如&#xff1a;宝塔运维、宝塔一键安装…