神经网络实战--使用迁移学习完成猫狗分类

在这里插入图片描述

前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

本文目录:

  • 一、加载数据集
    • 1.调用库函数
    • 2.加载数据集
    • 3.数据集管理
  • 二、猫狗数据集介绍
    • 1.猫狗数据集介绍:
    • 2.图片展示
  • 三、MobileNetV2网络介绍
    • 1.加载tensorflow提供的预训练模型
    • 2.轻量级网络——MobileNetV2
    • 3.MobileNetV2的网络模块
  • 四、搭建迁移学习
    • 1.训练
    • 2.训练结果可视化
    • 3.输出训练的准确率
    • 4.用cnn工具可视化一批数据的预测结果
    • 5.数据输出
    • 6.用cnn工具可视化一个数据样本的各层输出
    • 7.输出结果图像
  • 五、源码获取

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在这里插入图片描述
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
在这里插入图片描述

2.图片展示

下面是数据集中的图片展示:

class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络——MobileNetV2

使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩
在这里插入图片描述

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:
在这里插入图片描述
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride
在这里插入图片描述

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
在这里插入图片描述

四、搭建迁移学习

1.训练

inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

2.训练结果可视化

用图表显示准确率和损失函数

# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

3.输出训练的准确率

# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

4.用cnn工具可视化一批数据的预测结果

label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

5.数据输出

# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

6.用cnn工具可视化一个数据样本的各层输出

cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

五、源码获取

关注此公众号:人生苦短我用Pythons,回复 神经网络源码获取源码,快点击我吧

🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~

本期推荐:
Python自动化办公应用大全(ChatGPT版):从零开始教编程小白一键搞定烦琐工作(上下册)
在这里插入图片描述

抽奖方式:评论区随机抽取3位小伙伴免费送出,每人送两本(上下册,共六本)
参与方式:关注博主、点赞、收藏、评论区评论“人生苦短,我用Python!”切记要点赞+收藏,否则抽奖无效,每个人最多评论三次!)
活动截止时间:2023-05-27 20:00:00

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/43244.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Pytorch的猫狗分类

无偿分享~ 猫狗二分类文件下载地址 在下一章说 猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类。刚开始我找到了也运行成功了觉得可以了,最后看了一眼实践要求傻眼了,…

猫舍 靶场

目标url:http://117.41.229.122:8003/?id1 第一步,判断是否存在sql注入漏洞 构造 ?id1 and 11 ,回车 页面返回正常 构造 ?id1 and 12 ,回车 页面不正常,初步判断这里 可能 存在一个注入漏洞 第二步:判断字段数 构造 ?id1 and 11 order…

Olaparib 有望治疗 UBQLN4 过表达型肿瘤

基因组的不稳定性是人类遗传病和癌症的一大特点。在这篇文章当中,研究人员在常染色体隐性遗传病家族中发现了有害的 UBQLN4 突变。蛋白酶体穿梭因子UBQLN4 被 ATM 磷酸化并与泛素化的 MRE11 相互作用,从而介导早期的同源重组修复 (HRR)。在体外和体内实验…

神经内分泌肿瘤治疗新进展,神经内分泌肿瘤进展

脑肿瘤复发 。 颅内及椎管内肿瘤【概述】1.原发性肿瘤:起源于头颅、椎管内各种组织结构,如颅骨、脑膜、脑组织、颅神经、脑血管、脑垂体、松果体、脉络丛、颅内结缔组织、胚胎残余组织、脊膜、脊神经、脊髓、脊髓血管及脂肪组织等&#xff…

Nature Cancer | 发现非肿瘤药物的抗癌潜力

今天给大家介绍美国Broad Institute of MIT and Harvard的 Todd R. Golub团队发表在Nature cancer上的一篇文章:“Discovering the anticancer potential of nononcology drugs by systematic viability profiling“。在这个研究中,作者试图创建一个公共…

文献分享:定义的肿瘤抗原特异性T细胞增强了个性化的TCR-T细胞治疗和免疫治疗反应的预测

《Defifined tumor antigen-specifific T cells potentiate personalized TCR-T cell therapy and prediction of immunotherapy response》 简介 从患者体内自然发生的肿瘤抗原特异性T(Tas)细胞中提取的T细胞受体(TCRs)设计的T细胞将靶向其肿瘤中的个人TSAs。为了建立这种个性…

Irinotecan和vandetanib在治疗胰腺癌患者时产生协同效应(Irinotecan and vandetanib create synergies for t)

1. 摘要 背景:在胰腺癌(PAAD)中最常突变的基因对是KRAS和TP53,文章的目标是阐明KRAS/TP53突变的多组学和分子动力学图景,并为KRAS和TP53突变的PAAD患者获得有前景的新药物。此外,文章根据多组学数据尝试发现KRAS与TP53之间可能的联系。    …

癌症/肿瘤免疫治疗最新研究进展(2022年4月)

近年来,免疫治疗一直都是国内外肿瘤治疗研究领域的火爆热点,可以称之为革命性的突破。 除了大家熟知的PD-1/PD-L1已经先后斩获了包括肺癌、胃肠道肿瘤、乳腺癌、泌尿系统肿瘤、皮肤癌、淋巴瘤等在内的近20大实体肿瘤,成为免疫治疗的第一张王牌…

MCE | 癌症诊断和靶向治疗的“遍地开花”

据研究报道,很多癌细胞分泌的外泌体 (Exosome) 比正常细胞分泌的多 10 倍以上。外泌体参与了癌症的发生、进展、转移和耐药性,并通过转运蛋白和核酸,建立与肿瘤微环境的联系。例如,外泌体可导致免疫逃逸,癌细胞的免疫逃…

泛癌分析·找出各个癌症的预后相关基因

泛癌分析找出各个癌症的预后相关基因 ` 其他相关文章: 万物皆可pan分析高分文章登山梯for循环的熟练操作 前言 pan分析的第二篇我想写一下如何在TCGA整个基因集内实现COX单因素分析,将所有的预后相关基因筛选出来,同时得到这些基因的基本参数、统计量等信息。这样的分析的…

饮食干预减轻癌症治疗相关症状和毒性

现代化疗,放射疗法在摧毁癌细胞的同时,对健康细胞也造成了伤害,引发相关毒性,反应例如便秘,腹泻,疲劳,恶心,呕吐等。 癌症患者的营养状况可能是癌症治疗相关毒性的核心决定因素&…

边缘计算,是在炒概念吗?

导读:边缘计算概念刚出来的时候,很多人的第一反应是“这是哪个行业组织或者公司为了拉动市场需求而创造出来的新词汇吧?” 边缘计算究竟是什么?为什么会有边缘计算?它是一个全新的概念吗?谁在担任边缘计算的…

移动边缘计算笔记

该篇文章是阅读《移动边缘计算综述》所整理的笔记和心得,仅供参考,欢迎指正。 移动边缘计算(MEC),mobile edgecomputing,后来慢慢过渡为“多接入边缘计算”(multi-access edge computing&#x…

关于边缘计算和边云协同,看这一篇就够了~

几年前,大多数人都期望将物联网部署至云端,这的确可以给个人用户带来便捷的使用体验,但构建企业级的物联网解决方案,仍然需要采用云计算和边缘计算的结合方案。与纯粹的云端解决方案相比,包含边缘侧的混合方案可以减少…

边缘计算简介以及几款边缘计算开源平台

边缘计算中的边缘(edge)指的是网络边缘上的计算和存储资源,这里的网络边缘与数据中心相对,无论是从地理距离还是网络距离上来看都更贴近用户。作为一种新的计算范式,边缘计算将计算任务部署于接近数据产生源的网络边缘…

什么是边缘计算?

注:本篇翻译自施巍松教授的论文《Edge Computing : Vision and Challenges》 目录 文章目录 摘要简介什么是边缘计算什么是边缘计算边缘计算的优点 案例研究云卸载视频分析智能家居智慧城市 机遇和挑战编程可行性命名数据抽象服务管理私密性最优化指标 小结 摘要 …

中国电信边缘计算最佳实践

大数据、云计算、AI 等新一代信息技术的高速发展,在为新兴互联网行业提供强劲驱动之外,也在引领传统行业实施数字化、智能能化转型,并催生出智能制造、智慧金融等一系列全新智能产业生态。在中国电信 MEC 平台中,中国电信正依托自…

为什么需要边缘计算?哪些场景需要边缘计算?

为什么需要边缘计算? 边缘计算(Edge Computing)是一种将数据处理和计算功能移到接近数据源头的边缘设备上进行的计算模式。相比传统的云计算模式,边缘计算能够在接近数据源头的地方进行实时的数据处理,这为计算机视觉…

边缘计算是啥?

边缘计算,指的是在靠近物或数据源头的一侧,采用网络、计算、存储、应用核心能力为一体的开放平台,就近提供服务。与云计算作对比有助于更好地理解其特性,如果说云计算是集中化、规模化的,那么边缘计算就是分布式的、去…

边缘计算(三)——边缘计算的解决方案

点击篮字关注我们 目前,市场上存在的边缘计算相关概念包括雾计算、边缘计算、多接入边缘计算/移动边缘计算、移动云计算等概念。这是边缘计算的第三篇,主要讲的内容是边缘计算的解决方案。 Cloud Foundry平台 Cloud Foundry是一款使用Ruby开发的开源Paas…