任何使用 Keras 进行迁移学习

在前面的文章中,我们介绍了如何使用 Keras 构建和训练全连接神经网络(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)。本文将带你深入学习如何使用 迁移学习(Transfer Learning) 来加速和提升模型性能。我们将使用 Keras 和预训练的卷积神经网络(如 VGG16)来完成一个图像分类任务。

目录

  1. 什么是迁移学习
  2. 环境准备
  3. 导入必要的库
  4. 加载和预处理数据
  5. 加载预训练模型
  6. 构建迁移学习模型
  7. 编译模型
  8. 训练模型
  9. 评估模型
  10. 保存和加载模型
  11. 总结

1. 什么是迁移学习

迁移学习 是一种机器学习技术,它利用在一个任务上训练的模型来解决另一个相关任务。通过迁移学习,我们可以:

  • 加速训练: 利用预训练模型的特征提取能力,减少训练时间。
  • 提高性能: 在数据量有限的情况下,迁移学习可以显著提高模型的泛化能力。
  • 减少数据需求: 预训练模型已经在大规模数据集上训练过,可以减少对新数据的需求。

在图像分类任务中,迁移学习通常涉及使用在 ImageNet 等大型数据集上预训练的卷积神经网络(如 VGG16、ResNet、Inception 等),并将其应用到新的图像分类任务中。

2. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

3. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • ImageDataGenerator: 用于数据增强和预处理。
  • applications: 预训练模型模块,包含 VGG16、ResNet 等。

4. 加载和预处理数据

我们将使用 猫狗数据集(Cats vs Dogs),这是一个二分类图像数据集,包含 25,000 张猫和狗的图片。我们将使用 Keras 的 ImageDataGenerator 进行数据增强和预处理。

# 数据集路径
train_dir = 'data/train'
validation_dir = 'data/validation'# 图像参数
img_height, img_width = 150, 150
batch_size = 32# 训练数据生成器(数据增强)
train_datagen = ImageDataGenerator(rescale=1./255,               # 归一化rotation_range=40,            # 随机旋转width_shift_range=0.2,        # 随机水平平移height_shift_range=0.2,       # 随机垂直平移shear_range=0.2,              # 随机剪切zoom_range=0.2,               # 随机缩放horizontal_flip=True,         # 随机水平翻转fill_mode='nearest'           # 填充方式
)# 测试数据生成器(仅归一化)
test_datagen = ImageDataGenerator(rescale=1./255)# 加载训练数据
train_generator = train_datagen.flow_from_directory(train_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary'  # 二分类
)# 加载验证数据
validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(img_height, img_width),batch_size=batch_size,class_mode='binary'
)

说明:

  • 使用 ImageDataGenerator 进行数据增强,可以提高模型的泛化能力。
  • flow_from_directory 方法从目录中加载数据,目录结构应为 train_dir/class1/train_dir/class2/

5. 加载预训练模型

我们将使用预训练的 VGG16 模型,并冻结其卷积基(convolutional base),只训练顶部的全连接层。

# 加载预训练的 VGG16 模型,不包括顶部的全连接层
conv_base = applications.VGG16(weights='imagenet',include_top=False,input_shape=(img_height, img_width, 3))# 冻结卷积基
conv_base.trainable = False# 查看模型结构
conv_base.summary()

说明:

  • weights='imagenet': 使用在 ImageNet 数据集上预训练的权重。
  • include_top=False: 不包括顶部的全连接层,以便我们添加自己的分类器。
  • conv_base.trainable = False: 冻结卷积基,防止其权重在训练过程中被更新。

6. 构建迁移学习模型

我们将添加自己的全连接层来进行分类。

model = models.Sequential([conv_base,  # 预训练的卷积基layers.Flatten(),  # 展平层layers.Dense(256, activation='relu'),  # 全连接层layers.Dropout(0.5),  # Dropout 层,防止过拟合layers.Dense(1, activation='sigmoid')  # 输出层,二分类
])# 查看模型结构
model.summary()

说明:

  • 添加 Flatten 层将多维输出展平。
  • 添加 Dense 层和 Dropout 层进行分类。
  • 输出层使用 sigmoid 激活函数进行二分类。

7. 编译模型

model.compile(optimizer=keras.optimizers.Adam(),loss='binary_crossentropy',metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和二元交叉熵损失函数。
  • 评估指标为准确率。

8. 训练模型

# 设置训练参数
epochs = 10# 训练模型
history = model.fit(train_generator,steps_per_epoch=train_generator.samples // batch_size,epochs=epochs,validation_data=validation_generator,validation_steps=validation_generator.samples // batch_size
)

说明:

  • steps_per_epoch: 每个 epoch 的步数,通常为训练样本数除以批量大小。
  • validation_steps: 每个 epoch 的验证步数,通常为验证样本数除以批量大小。

9. 评估模型

test_loss, test_acc = model.evaluate(validation_generator, steps=validation_generator.samples // batch_size)
print(f"\n测试准确率: {test_acc:.4f}")

10. 保存和加载模型

# 保存模型
model.save("cats_vs_dogs_transfer_learning.h5")# 加载模型
new_model = keras.models.load_model("cats_vs_dogs_transfer_learning.h5")

11. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')plt.show()

12. 解冻部分卷积基进行微调

为了进一步提高模型性能,可以解冻部分卷积基,进行微调。

# 解冻最后几个卷积层
conv_base.trainable = True# 查看可训练的参数
for layer in conv_base.layers:if layer.name == 'block5_conv1':breaklayer.trainable = False# 重新编译模型
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # 使用较低的学习率loss='binary_crossentropy',metrics=['accuracy'])# 继续训练模型
history_fine = model.fit(train_generator,steps_per_epoch=train_generator.samples // batch_size,epochs=5,validation_data=validation_generator,validation_steps=validation_generator.samples // batch_size
)

说明:

  • 解冻部分卷积层,并使用较低的学习率进行微调。
  • 继续训练模型以微调预训练模型的权重。

13. 课程回顾

本文其实不算什么知识点,只是利用迁移学习来加速训练的一个实际操作的例子。

作者简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/471450.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web安全测试渗透案例知识点总结(上)——小白入狱

目录 一、Web安全渗透测试概念详解1. Web安全与渗透测试2. Web安全的主要攻击面与漏洞类型3. 渗透测试的基本流程 二、知识点详细总结1. 常见Web漏洞分析2. 渗透测试常用工具及其功能 三、具体案例教程案例1:SQL注入漏洞利用教程案例2:跨站脚本&#xff…

浪潮信息“源”Embedding模型登顶MTEB榜单第一名

在自然语言处理(NLP)和机器学习领域,Embedding模型是将文本数据转换为高维向量表示的核心技术,直接影响NLP任务(如文本分类、情感分析等)的效果,对于提升模型性能和深入理解文本语义具有至关重要…

catchadmin-webman 宝塔 部署

1:宝塔的php 中删除禁用函数 putenv 问题: 按照文档部署的时候linux(php) vue (本地) 无法访问后端api/login 的接口 。 解决办法: webman 没有配置nginx 反向代理 配置就能正常访问了

【AutoGen 】简介

学习笔记AutoGen。它可以使用多个代理来开发 LLM 应用程序,这些代理可定制、可相互对话,可在各种模式下运行,且无缝允许人的参与,进一步在更大程度上为开发者提供助力。AutoGen 智能应用开发(一)|AutoGen 基础 学习笔记

【月之暗面kimi-注册/登录安全分析报告】

前言 由于网站注册入口容易被机器执行自动化程序攻击,存在如下风险: 暴力破解密码,造成用户信息泄露,不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 ,造成用户无法登陆、注册,大量收到垃圾短信的…

统信UOS开发接口DTK

DTK(Development ToolKit)是基于 Qt 开发的简单且实用的通用开发框架。提供丰富的开发接口与支持工具,能有效提升开发效率。 文章目录 一、简介DTK 常见模块介绍概述二、框架创建开发环境准备使用 cmake三、常见模块窗口和对话框一、简介 DTK 常见模块介绍 概述 DTK(Dev…

城市轨道交通数据可视化的应用与优势

通过图扑可视化技术将复杂的数据转化为易于理解的图像,助力交通管理者优化线路规划、提升运营效率和乘客信息服务。轨道交通管理者能够更直观地分析乘客流量、运营效率等关键指标,从而优化线路设计与调度,提高服务质量,为乘客提供…

【JavaEE初阶 — 多线程】生产消费模型 阻塞队列

1. 阻塞队列 (1) 阻塞队列 1. 概念 阻塞队列是一种特殊的队列,也遵守"先进先出"的原则;阻塞队列能是一种线程安全的数据结构,主要用来阻塞队列的插入和获取操作: 当队列满了的时候,插入操作会被…

重构开发之道,Blackbox.AI为技术注入智能新动力

本文目录 一、引言二、Blackbox.AI实战体验2.1 基于网页界面生成前端代码进行应用开发2.2 与AI助手实现实时智能对话2.3 重塑大型文件交互方式2.4 链接Github仓库进行对话编程 三、总结 一、引言 在生产力工具加速进化的浪潮中,Blackbox.AI开始崭露头角&#xff0c…

idea 弹窗 delete remote branch origin/develop-deploy

想删除远程分支,就选delete,仅想删除本地分支,选cancel; 在 IntelliJ IDEA 中遇到弹窗提示删除远程分支 origin/develop-deploy,这通常是在 Git 操作过程中出现的情况,可能是在执行如 git branch -d 或其他…

第四十五章 Vue之Vuex模块化创建(module)

目录 一、引言 二、模块化拆分创建方式 三、模块化拆分完整代码 3.1. index.js 3.2. module1.js 3.3. module2.js 3.4. module3.js 3.5. main.js 3.6. App.vue 3.7. Son1.vue 3.8. Son2.vue 四、访问模块module的state ​五、访问模块中的getters ​六、mutati…

【OpenEuler】配置虚拟ip

OpenEuler系统手动配置虚ip 介绍操作方法临时生效永久生效 验证 介绍 我们知道通过keepalived服务可以为linux服务器设置虚拟ip,但是有些特殊场景下若无法安装部署keepalived服务,则需要通过手动设置的方式,配置服务器的虚拟ip。 本方案提供…

CCI3.0-HQ:用于预训练大型语言模型的高质量大规模中文数据集

摘要 我们介绍了 CCI3.0-HQ,它是中文语料库互联网 3.0(CCI3.0)的一个高质量500GB子集,采用新颖的两阶段混合过滤管道开发,显著提高了数据质量。为了评估其有效性,我们在不同数据集的100B tokens上从头开始…

fastadmin多个表crud连表操作步骤

1、crud命令 php think crud -t xq_user_credential -u 1 -c credential -i voucher_type,nickname,user_id,voucher_url,status,time --forcetrue2、修改控制器controller文件 <?phpnamespace app\admin\controller;use app\common\controller\Backend;/*** 凭证信息…

安装SQL server中python和R

这两个都是编程语言 R 是一种专门为统计计算和数据分析而设计的语言&#xff0c;它具有丰富的统计函数和绘图工具&#xff0c;常用于学术研究、数据分析和统计建模等领域。 Python 是一种通用型编程语言&#xff0c;具有简单易学、语法简洁、功能强大等特点。它在数据科学、机…

项目技术栈-解决方案-web3去中心化

web3去中心化 Web3 DApp区块链:钱包:智能合约:UI:ETH系开发技能树DeFi应用 去中心化金融P2P 去中心化网络参考Web3 DApp 区块链: 以以太坊(Ethereum)为主流,也包括Solana、Aptos等其他非EVM链。 区块链本身是软件,需要运行在一系列节点上,这些节点组成P2P网络或者半…

【linux】centos7 换阿里云源

查看yum配置文件 yum的配置文件通常位于/etc/yum.repos.d/目录下。你可以使用以下命令查看这些文件&#xff1a; ls /etc/yum.repos.d/ # 或者 ll /etc/yum.repos.d/备份当前的yum配置文件 建议备份当前的yum配置文件&#xff1a; sudo cp /etc/yum.repos.d/CentOS-Base.re…

Python 中.title()函数和.lower()函数

一.title()函数 1.title()函数的功能 将字符串中的每一单词的首字母大写 2.举例 S1"i love you" S2S1.title() print(S2)3.输出 二.lower()函数 1.lower()函数的功能 将字符串中的每一大写字母都变成的小写字母 2.举例 S1"I LOVE YOU" S2S1.lower()…

[DEBUG] 服务器 CORS 已经允许所有源,仍然有 304 的跨域问题

背景 今天有一台服务器到期了&#xff0c;准备把后端迁移到另一台服务器上&#xff0c;结果前端在测试的时候&#xff0c;出现了 304 的跨域问题。 调试过程中出现的问题&#xff0c;包括但不限于&#xff1a; set the request’s mode to ‘no-cors’Redirect is not allow…

【AI构思渲染】网络直播——建筑绘图大模型生成渲染图

家人们&#xff01;&#xff01;好消息来了&#xff01;&#xff01; 2024年11月19日&#xff0c;上午10:00-11:00 构力学堂将会给大家带来一场直播课《AI构思渲染第一课&#xff0c;建筑绘图大模型生成渲染图》 课程亮点&#xff1a; 1、AI插件相关介绍 2、AI构思渲染安装…