创建和探索VGG16模型

        PyTorch在torchvision库中提供了一组训练好的模型。这些模型大多数接受一个称为 pretrained 的参数,当这个参数为True 时,它会下载为ImageNet 分类问题调整好的权重。让我们看一下创建 VGG16模型的代码片段:

from torchvision import models
vgg = models.vggl6(pretrained=True)

        现在有了所有权重已经预训练好且可马上使用的VGG16模型。当代码第一次运行时,可能需要几分钟,这取决于网络速度。权重的大小可能在500MB左右。我们可以通过打印快速查看下 VGG16模型。当使用现代架构时,理解这些网络的实现方式非常有用。我们来看看这个模型:

VGG((features): Sequential((0):Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))(1):ReLU (inplace)(2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))(3):ReLU(inplace)(4):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(5):Conv2d(64,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))(6):ReLU(inplace)(7):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))(8):ReLU(inplace)(9):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(10):Conv2d(128,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(11):ReLU(inplace)(12):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(13):ReLU(inplace)(14):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(15):ReLU(inplace)(16):MaxPool2d(size=(2,2),stride=(2,2)dilation=(1,1))(17):Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(18):ReLU(inplace)(19):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(20):ReLU(inplace)(21):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(22):ReLU(inplace)(23):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(24):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(25):ReLU(inplace)(26):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(27):ReLU(inplace)(28):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(29):ReLU(inplace)(30):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1)))(classifier):Sequential((0):Linear(25088>4096)(1):ReLU(inplace)(2):Dropout(p=0.5)(3):Linear(4096->4096)(4):ReLU (inplace)(5):Dropout(p=0.5)(6):Linear(4096>1000))
)

        模型摘要包含了两个序列模型:features和classifiers。features和sequentia1模型包含了将要冻结的层。

冻结层

        下面冻结包含卷积块的features模型的所有层。冻结层中的权重将阻止更新这些卷积块的权重。由于模型的权重被训练用来识别许多重要的特征,因而我们的算法从第一个迭代开时就具有了这样的能力。使用最初为不同用例训练的模型权重的能力,被称为迁移学习。现在看一下如何冻结层的权重或参数:

for param in vgg.features.parameters():param.requires_grad = False

        该代码阻止优化器更新权重。

微调VGG16模型

        VGG16模型被训练为针对1000个类别进行分类,但没有训练为针对狗和猫进行分类。因此,需要将最后一层的输出特征从1000改为2。以下代码片段执行此操作:

vgg.classifier[6].out_features = 2

        vgg.classifier可以访问序列模型中的所有层,第6个元素将包含最后一个层。当训练VGG16模型时,只需要训练分类器参数。因此,我们只将classifier.parameters传入优化器,如下所示:

optimizer=
optim.SGD(vgg.classifier.parameters(),lr=0.0001,momentum=0.5)

训练VGG16模型

        我们已经创建了模型和优化器。由于使用的是Dogs vs. Cats数据集,因此可以使用相同的数据加载器和train函数来训练模型。请记住,当训练模型时,只有分类器内的参数会发生变化。下面的代码片段对模型进行了20轮的训练,在验证集上达到了98.45%的准确率:

train_losses, train_accuracy =[],[]
val_losses, val_accuracy =[],[]
for epoch in range(l,20):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        将训练和验证的损失可视化,如图5.19所示。

        将训练和验证的准确率可视化,如图5.20所示:

        我们可以应用一些技巧,例如数据增强和使用不同的dropout值来改进模型的泛化能力。以下代码片段将 VGG分类器模块中的dropout值从0.5更改为0.2并训练模型:

for layer in vgg.classifier.children():if(type(layer)== nn.Dropout):layer.p=0.2
#训练
train_losses,train_accuracy = [][]
val_losses, val accuracy =[],[ ]
for epoch in range(1,3):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        通过几轮的训练,模型得到了些许改进。还可以尝试使用不同的dropout值。改进模型泛化能力的另一个重要技巧是添加更多数据或进行数据增强。我们将通过随机地水平翻转图像或以小角度旋转图像来进行数据增强。torchvision转换为数据增强提供了不同的功能,它们可以动态地进行,每轮都发生变化。我们使用以下代码实现数据增强:

train transform =transforms.Compose([transforms,Resize((224,224)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
train = ImageFolder('dogsandcats/train/',train_transform)
valid = ImageFolder('dogsandcats/valid/',simple_transform)
#训练
train_losses,train_accuracy=[][]
val_losses,val_accuracy = [],[]
for epoch in range(1,3):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        前面的代码输出如下:

#结果
training loss is 0.041 and training accuracy is 22657/23000 98.51
validation loss is 0.043 and validation accuracy is 1969/2000 98.45
training loss is 0.04 and training accuracy is 22697/23000 98.68 
validation loss is 0.043 and validation accuracy is 1970/2000 98.5

        使用增强数据训练模型仅运行两轮就将模型准确率提高了0.1%;可以再运行几轮以进一步改进模型。如果大家在阅读本书时一直在训练这些模型,将意识到每轮的训练可能需要几分钟,具体取决于运行的GPU。让我们看一下可以在几秒钟内训练一轮的技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/359750.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视图(views)

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 下面通过一个例子讲解在Django项目中定义视图,代码如下: from django.http import HttpResponse # 导入响应对象 impo…

Flutter【组件】点击类型表单项

简介 flutter 点击表单项组件,适合用户输入表单的场景。 点击表单项组件是一个用户界面元素,通常用于表单或设置界面中,以便用户可以点击它们来选择或更改某些设置或输入内容。这类组件通常由一个标签和一个可点击区域组成,并且…

Redis-数据类型-zset

文章目录 1、查看redis是否启动2、通过客户端连接redis3、切换到db4数据库4、将一个或多个member元素及其score值加入到有序集key当中5、升序返回有序集key6、升序返回有序集key,让分数一起和值返回的结果集7、降序返回有序集key,让分数一起和值返回到结…

Charles抓包工具系列文章(一)-- Compose 拼接http请求

一、背景 众所周知,Charles是一款抓包工具,当然是http协议,不支持tcp。(如果你想要抓tcp包,请转而使用wireshark,在讲述websocket的相关技术有梳理过wireshark抓包) 话说回来,char…

浏览器自带的IndexDB的简单使用示例--小型学生管理系统

浏览器自带的IndexDB的简单使用示例--小型学生管理系统 文章说明代码效果展示 文章说明 本文主要为了简单学习IndexDB数据库的使用&#xff0c;写了一个简单的增删改查功能 代码 App.vue&#xff08;界面的源码&#xff09; <template><div style"padding: 30px&…

红队内网攻防渗透:内网渗透之内网对抗:横向移动篇域控系统提权NetLogonADCSPACKDC永恒之蓝CVE漏洞

红队内网攻防渗透 1. 内网横向移动1.1 横向移动-域控提权-CVE-2020-1472 NetLogon1.2 横向移动-域控提权-CVE-2021-422871.3 横向移动-域控提权-CVE-2022-269231.4 横向移动-系统漏洞-CVE-2017-01461.5 横向移动-域控提权-CVE-2014-63241. 内网横向移动 1、横向移动-域控提权-…

elementui组件库实现电影选座面板demo

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Cinema Seat Selection</title><!-- 引入E…

【学一点儿前端】单页面点击前进或后退按钮导致的内存泄露问题(history.listen监听器清除)

今天测试分配了一个比较奇怪的问题&#xff0c;在单页面应用中&#xff0c;反复点击“上一步”和“下一步”按钮时&#xff0c;界面表现出逐渐变得卡顿。为分析这一问题&#xff0c;我用Chrome的性能监控工具进行了浏览器性能录制。结果显示&#xff0c;每次点击“上一步”按钮…

区间预测 | Matlab实现CNN-ABKDE卷积神经网络自适应带宽核密度估计多变量回归区间预测

区间预测 | Matlab实现CNN-ABKDE卷积神经网络自适应带宽核密度估计多变量回归区间预测 目录 区间预测 | Matlab实现CNN-ABKDE卷积神经网络自适应带宽核密度估计多变量回归区间预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现CNN-ABKDE卷积神经网络自适应…

思考题:相交的几何图形

给定不超过 26 个几何图形&#xff0c;每个图形都有一个唯一大写字母作为其编号。 每个图形在平面中的具体位置已知&#xff0c;请你判断&#xff0c;对于每个图形&#xff0c;有多少个其他图形与其存在交点。 在判断交点时&#xff0c;只考虑边与边相交的情况&#xff0c;如…

Java 8 Date and Time API

Java 8引入了新的日期和时间API&#xff0c;位于java.time包下&#xff0c;旨在替代旧的java.util.Date和java.util.Calendar类。新API更为简洁&#xff0c;易于使用&#xff0c;并且与Joda-Time库的一些理念相吻合。以下是Java 8 Date and Time API中几个核心类的简要概述&…

AIGC-CVPR2024best paper-Rich Human Feedback for Text-to-Image Generation-论文精读

Rich Human Feedback for Text-to-Image Generation斩获CVPR2024最佳论文&#xff01;受大模型中的RLHF技术启发&#xff0c;团队用人类反馈来改进Stable Diffusion等文生图模型。这项研究来自UCSD、谷歌等。 在本文中&#xff0c;作者通过标记不可信或与文本不对齐的图像区域&…

【网络协议】精讲ARP协议工作原理!图解超赞超详细!!!

亲爱的用户&#xff0c;打开微信&#xff0c;搜索公众号&#xff1a;“风云说通信”&#xff0c;即可免费阅读该文章~~ 目录 前言 1. ARP协议介绍 1.1 ARP协议功能 1.2 ARP请求报文 1.3 ARP工作原理 2. ARP 缓存超时 2.1 RARP 3. ARP 攻击 3.1 ARP 攻击分类 前言 首先…

HTML(16)——边距问题

清楚默认样式 很多标签都有默认的样式&#xff0c;往往我们不需要这些样式&#xff0c;就需要清楚默认样式 写法&#xff1a; 用通配符选择器&#xff0c;选择所有标签&#xff0c;清除所有内外边距选中所有的选择器清楚 *{ margin:0; padding:0; } 盒子模型——元素溢出 作…

超越AnimateAnyone, 华中科大中科大阿里提出Unimate,可以根据单张图片和姿势指导生成视频。

阿里新发布的UniAnimate&#xff0c;与 AnimateAnyone 非常相似&#xff0c;它可以根据单张图片和姿势指导生成视频。项目核心技术是统一视频扩散模型&#xff0c;通过将参考图像和估计视频内容嵌入到共享特征空间&#xff0c;实现外观和动作的同步。 相关链接 项目&#xff1…

Eclipse使用TFS(Team Foundation Server) 超详细

Eclipse使用TFS 1、什么是TFS2、TFS和Git的区别3、签出代码4、签入代码4.1、签出以进行编辑4.2、修改本地代码4.3、签入挂起的更改4.4、签入 如果不能 签入挂起的更改&#xff0c;则先 签出以进行编辑如果 签入挂起的更改不可选中&#xff0c;则 如下操作 1、什么是TFS Team F…

fastadmin多语言切换设置

fastadmin版本&#xff1a;1.4.0.20230711 以简体&#xff0c;繁体&#xff0c;英文为例 一&#xff0c;在application\config.php 里开启多语言 // 是否开启多语言lang_switch_on > true, // 允许的语言列表allow_lang_list > [zh-cn, en,zh-tw], 二…

达梦数据守护集群部署

接上篇 达梦8单机规范化部署 https://blog.csdn.net/qq_25045631/article/details/139898690 1. 集群规划 在正式生产环境中&#xff0c;两台机器建议使用统一配置的服务器。使用千兆或千兆以上网络。 两台虚拟机各加一块网卡&#xff0c;仅主机模式&#xff0c;作为心跳网卡…

Notepad++插件 Hex-Edit

Nptepad有个Hex文件查看器&#xff0c;苦于每次打开文件需要手动开插件显示Hex&#xff0c;配置一下插件便可实现打开即调用 关联多个二进制文件&#xff0c;一打开就使用插件的方法&#xff0c;原来是使用空格分割&#xff01;&#xff01;&#xff01;

创新指南|品牌电商新策略:五大转型思路与RGM举措

在流量红利过去的背景下&#xff0c;品牌电商面对多渠道运营的难题&#xff0c;如缺乏统盘经营、绩效管理分散、价格战失控、用户体验不足以及流量过度依赖&#xff0c;品牌电商如何有效应对这些挑战&#xff0c;本文从5个维度探讨全渠道电商RGM破局之路&#xff0c;实现品牌的…