【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习

一、说明

        在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么它被称为迁移学习;我们将现有模型的学习转移到新的数据集中。

教程概述:

  1. 介绍
  2. 使用内置的 TensorFlow 模型进行迁移学习
  3. 使用 TensorFlow Hub 进行迁移学习

二、为什么迁移学习简介

之        我们已经探讨了如何使用数据增强来提高模型性能。现在的问题是,“如果我们没有足够的数据来从头开始训练我们的网络怎么办?

对        此的解决方案是使用迁移学习方法。一篇更具理论意义的帖子已经发布在我们的博客上。如果需要,请查看它以刷新一些想法。我们可以使用迁移学习将知识从一些预先训练好的开源网络转移到我们自己的简历问题中。

计        算机视觉研究社区在互联网上发布了许多数据集,如Imagenet或MS Coco或Pascal数据集。许多计算机视觉研究人员已经在这些数据集上训练了他们的算法。有时,此培训需要数周时间,并且可能需要许多 GPU。事实上,其他人已经完成了这项任务并经历了痛苦的高性能研究过程,这意味着我们经常可以下载开源权重。

有        很多网络已经过训练。例如,Imagenet 数据集,它由 1000 个不同的类和超过 14 万张图像组成。因此,网络可能有一个 softmax 单元,它输出一千个可能的类之一。我们可以做的是摆脱softmax层并创建我们自己的输出单元来表示例如猫或

        由于我们使用下载的权重,我们将只训练与我们的输出层关联的参数,在我们的例子中,这将是一个 sigmoid 输出层。

三、使用内置的 TensorFlow 模型进行迁移学习

        让我们首先为训练准备数据集。我们将使用 wget.download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile
wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        现在,让我们导入所有必需的库并构建模型。我们将使用一个名为MobileNetV2的预训练网络,该网络在ImageNet数据集上进行训练。在这里,我们希望使用除顶部分类图层之外的所有图层,因此我们不会将它们包含在我们的网络中。

        obileNetV2 体系结构概述:https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html

这        个模型和其他预训练模型已经在TensorFlow中可用。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGeneratorbase_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.summary()


输出:
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]  

        果最后一层输出不同数量的类,那么我们需要有自己的输出单元来输出以下类:。有几种方法可以做到这一点:

  • 取最后几层的权重,将它们用作初始化并进行梯度下降。这样,我们将重新训练网络的一部分。
  • 删除最后几层的权重,使用我们自己的新隐藏单元和我们自己的最终 sigmoid(或 softmax)输出。通过这种方式,我们可以更改输出的数量。

        因此,这两种方法中的任何一种都值得尝试。

        现在让我们冻结预训练层并添加一个新层,称为 GlobalAveragePooling2D,之后是具有 sigmoid 激活函数的 Dense 层。

现        

base_model.trainable = Falsemodel = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.summary()

    

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

          现在是训练步骤的时候了。我们将使用图像数据生成器来迭代图像。现在没有必要为大量时期训练网络,因为我们已经预先训练了一部分网络。

model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)

让        我们看看结果。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy, label="Training")
plt.plot(epochs, val_accuracy, label="Validation")
plt.legend()
plt.title('Training and validation accuracy')
plt.figure()plt.plot(epochs, loss, label="Training")
plt.plot(epochs, val_loss, label="Validation")
plt.legend()
plt.title('Training and validation loss')

3.Text(0.5, 1.0, 'Training and validation loss')

四、 使用 TensorFlow Hub 进行迁移学习

        访问预训练模型的另一种方法是TensorFlow Hub。TensorFlow Hub是一个库,用于发布,发现和使用机器学习模型的可重用部分。您可以在此处找到更多预训练模型。

        我们将冻结图层并添加一个用于分类的新图层。全连接网络的输入称为瓶颈要素。它们表示网络中最后一个卷积层的激活图。

我们为给定数量的 epoch 训练模型。

        在这里,我们也可以使用 TensorBoard,但这次让我们保持简单。

        最后,让我们可视化一些预测。在此数据集中,猫标记为 0,狗标记为 1。我们将用蓝色显示正确的预测,用红色显示假。

预测

五、总结

        正如我们在上面看到的,使用迁移学习可以帮助我们在短时间内取得非常好的结果。使用数据增强,可以进一步增强结果。

         在下一篇文章中,我们将展示如何创建网络并将其转换以在移动设备上使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/159065.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯 第 1 场算法双周赛 第1题 三带一 c++ map 巧解 加注释

题目 三带一【算法赛】https://www.lanqiao.cn/problems/5127/learning/?contest_id144 问题描述 小蓝和小桥玩斗地主,小蓝只剩四张牌了,他想知道是否是“三带一”牌型。 所谓“三带一”牌型,即四张手牌中,有三张牌一样&#…

CSS学习基础知识

CSS学习笔记 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width,…

独立式三相无源逆变电源设计

摘要 面对全球日趋严重的能源危机问题&#xff0c;可再生能源的开发和利用得到了人们的高度重视。其中辐射到地球太阳能资源是十分富饶的&#xff0c;绿色清洁的太阳能不会危害我们的生存环境&#xff0c;因而受到了人们的广泛利用。光伏发电作为可再生能源被广泛的应用&#x…

RabbitMq启用TLS

Windows环境 查看配置文件的位置 选择使用的节点 查看当前节点配置文件的配置 配置TLS 将证书放到同配置相同目录中 编辑配置文件添加TLS相关配置 [{ssl, [{versions, [tlsv1.2]}]},{rabbit, [{ssl_listeners, [5671]},{ssl_options, [{cacertfile,"C:/Users/17126…

如何定制化跑腿小程序源码

跑腿小程序源码为您提供了一个强大的起点&#xff0c;但要创建一个成功的本地服务平台&#xff0c;您通常需要对源码进行定制化。这篇文章将介绍如何定制化跑腿小程序源码&#xff0c;包括添加新功能、修改界面和优化用户体验。 选择合适的跑腿小程序源码 首先&#xff0c;您…

Linux查看端口号及进程信息

Linux查看端口号及进程 Linux查看端口号 netstat netstat -tuln显示当前正在监听的端口号以及相关的进程信息 ss ss -tuln与netstat类似&#xff0c;ss也可以用于显示当前监听的端口以及相关信息 isof isof -i :端口号端口号替换为具体要查找的端口号&#xff0c;显示该端…

Leetcode 75——1768.交替合并字符串 解题思路与具体代码【C++】

一、题目描述与要求 1768. 交替合并字符串 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你两个字符串 word1 和 word2 。请你从 word1 开始&#xff0c;通过交替添加字母来合并字符串。如果一个字符串比另一个字符串长&#xff0c;就将多出来的字母追加到合并后字符…

DOSBox和MASM汇编开发环境搭建

DOSBox和MASM汇编开发环境搭建 1 安装DOSBox2 安装MASM3 编译测试代码4 运行测试代码5 调试测试代码 本文属于《 X86指令基础系列教程》之一&#xff0c;欢迎查看其它文章。 1 安装DOSBox 下载DOSBox和MASM&#xff1a;https://download.csdn.net/download/u011832525/884180…

MySQL操作合集

数据库的操作 创建数据库 create database [if not exists] db_name [character set utf8] [collate utf8_general_ci];查看所有数据库 show databases;查看数据库的创建语句 show create database db_name;修改数据库 alter database db_name character set utf8 colla…

Linux之open/close/read/write/lseek记录

一、文件权限 这里不做过多描述&#xff0c;只是简单的记录&#xff0c;因为下面的命令会涉及到。linux下一切皆是文件包括文本、硬件设备、管道、数据库、socket等。通过ls -l 命令可以查看到以下信息 drwxrwxrwx 1 root root 0 Oct 10 17:06 open -rwxrwxrwx 1 root roo…

内网渗透——隧道代理

文章目录 代理代理使用场景VPS建立隧道frpMSF木马生成监听开启frp服务端和客户端执行exe木马文件 代理 实验环境&#xff1a; 攻击机kali&#xff1a;192.168.188.133&#xff08;NAT模式&#xff09; 模拟的公网服务器&#xff08;本机&#xff09;&#xff1a;10.9.75.239 …

【数据库——MySQL(实战项目1)】(4)图书借阅系统——触发器

目录 1. 简述2. 功能代码2.1 创建两个触发器&#xff0c;分别在借出或归还图书时&#xff0c;修改借阅人表中的已借数目(附加&#xff1a;借阅人表的总借书数、图书表的借阅次数以及更新图书表的图书状态为(已借出/在架上))字段&#xff1b;2.2 创建触发器&#xff0c;当借阅者…

相似性搜索:第 3 部分--混合倒排文件索引和产品量化

接续前文&#xff1a;相似性搜索&#xff1a;第 2 部分&#xff1a;产品量化 SImilarity 搜索是一个问题&#xff0c;给定一个查询的目标是在所有数据库文档中找到与其最相似的文档。 一、介绍 在数据科学中&#xff0c;相似性搜索经常出现在NLP领域&#xff0c;搜索引擎或推…

Codeforces Round 903 (Div. 3)

D. Divide and Equalize Example input Copy 7 5 100 2 50 10 1 3 1 1 1 4 8 2 4 2 4 30 50 27 20 2 75 40 2 4 4 3 2 3 1 output Copy YES YES NO YES NO YES NONote The first test case is explained in the problem statement. 很重要很重要的知识点&a…

Windows端口号被占用的查看方法及解决办法

Windows端口号被占用的查看方法及解决办法 Error starting ApplicationContext. To display the conditions report re-run your application with debug enabled. 2023-10-14 22:58:32.069 ERROR 6488 --- [ main] o.s.b.d.LoggingFailureAnalysisReporter : ***…

CustomNavBar 自定义导航栏视图

1. 创建偏好设置键 CustomNavBarTitlePreferenceKey.swift import Foundation import SwiftUI//State private var showBackButton: Bool true //State private var title: String "Title" //"" //State private var subtitle: String? "SubTitl…

算法练习13——跳跃游戏II

LeetCode 45 跳跃游戏 II 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向前跳转的最大长度。换句话说&#xff0c;如果你在 nums[i] 处&#xff0c;你可以跳转到任意 nums[i j] 处: 0 < j < nums[i] i j < n 返回…

螺杆支撑座对注塑机的生产过程有哪些重要影响?

螺杆支撑座对注塑机的生产过程具有重要影响&#xff0c;主要体现在以下几个方面&#xff1a; 1、精度和稳定性&#xff1a;螺杆支撑座能够提高注塑机的精度和稳定性&#xff0c;从而保证塑料制品的品质和一致性。通过提供稳定的支撑和承载&#xff0c;螺杆支撑座可以减少机器运…

React18入门(第三篇)——React Hooks详解,React内置Hooks、自定义Hooks使用

文章目录 概述一、内置 Hook——useState1.1 响应式数据更新1.2 什么是 state1.3 state 特点&#xff08;一&#xff09;——异步更新1.4 state 特点&#xff08;二&#xff09;——可能会被合并1.5 state 特点&#xff08;三&#xff09;——不可变数据&#xff08;重要&#…

MySQL的各种锁

1. MySQL有遇到过死锁的问题吗&#xff1f;你是如何解决的&#xff1f; 死锁&#xff0c;就是两个或两个以上的线程在执行过程中&#xff0c;去争夺同一个共享资源导致互相等待的现象&#xff0c;在没有外部干预的情况下&#xff0c;线程会一直处于阻塞状态&#xff0c;无法往下…