(动手学习深度学习)第13章 计算机视觉---微调

文章目录

    • 微调
      • 总结
    • 微调代码实现

微调

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

在这里插入图片描述
查看数据集中图像的形状

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)

在这里插入图片描述

  1. 数据增强
# 图像增广
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强[torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强[torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize]
)
  1. 定义和初始化模型
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)

在这里插入图片描述

  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512—>1000,改成512—>2)
  • (2)在增加一层分类层(如:512—>1000, 改成512—>1000, 1000—>2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

在这里插入图片描述

print(finetune_net)

在这里插入图片描述

  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)device = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='none')if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(),lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(finetune_net, 5e-5, 128, 10)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

直接训练:整个模型都使用相同的学习率,重新训练

scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(scracth_net, 5e-4, param_group=False)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/198750.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux网络——HTTP

目录 一.应用层 二.认识URL 1.域名 2.urlencode和urldecode 三.HTTP协议格式 1.请求格式http 2.响应格式 四.HTTP响应状态码 五.HTTP常见Header 六.简单的HTTP服务器 七.HTTP的方法 1.GET方法 2.POST方法 一.应用层 我们程序员写的一个个解决我们实际问题, 满足…

深入分析高性能互连点对点通信开销

今天分享最近阅读的一篇文章:“Breaking Band,A Breakdown of High-Performance Communication”,这篇文章发表在ICPP 2019会议。由加州大学欧文分校和ARM公司合作完成。从题目中可以看到,这篇文章是一篇评测型的文章,…

JDY蓝牙注意事项

波特率设置:9600,不接受115200,或者38400. 不同于WiFi测试,jdy蓝牙不接受AT"指令,可以使用“ATVERSION"指令测试 安信可公司的那个蓝牙指令在这里没有用,不知道是不是生产的公司不一样

网络协议入门 笔记一

一、服务器和客户端及java的概念 JVM (Java Virtual Machine) : Java虚拟机,Java的跨平台:一次编译,到处运行,编译生成跟平台无关的字节码文件 (class文件),由对应平台的JVM解析字节码为机器指令 (010101)。 如下图所示&#xff0…

Java入门篇 之 抽象类接口

本篇碎碎念:个人认为压力是一种前进的动力,但是不要有太多压力,不然会使心情烦躁,会控制不住自己的情绪,会在一个临界值爆发,一旦爆发,将迟迟不能消散 今日份励志文案: 努力的背后必有加倍的赏赐…

Word中NoteExpress不显示的问题

首先确认我们以及安装了word插件 我们打开word却没有。此时我们打开:文件->选项->加载项 我们发现被禁用了 选择【禁用项目】(如果没有,试一试【缓慢且禁用的加载项】),点击转到 选择启用 如果没有禁用且没有出…

海康威视综合安防管理平台任意文件上传

系统介绍 HIKVISION iSecure Center综合安防管理平台是一套“集成化”、“智能化”的平台,通过接入视频监控、一卡通、停车场、报警检测等系统的设备,获取边缘节点数据,实现安防信息化集成与联动,公众号:web安全工具库…

使用centos搭建内网的yum源

1.安装httpd服务 2.启动服务,设置开机自启 #启动服务 systemctl start httpd # 设置开机自动启动 systemctl enable httpd systemctl status httpd3.新建一个目录,将rpm文件放到该目录下 4.将/etc/httpd/conf/httpd.conf文件中的DocumentRoot "…

SASS/SCSS精华干货教程

目录 介绍 基本说明 特点 sass语法格式sass的语法格式一共有两种,一种是以".scss"作为拓展名,一种是以".sass"作为拓展名,这里我们只讲拓展名: 编译环境安装 Vscode安装编译插件 简单使用 sass语法扩张…

合并两个有序链表(冒泡排序实现)

实例要求:将两个升序链表合并为一个新的 升序 链表并返回;新链表是通过拼接给定的两个链表的所有节点组成的;实例分析:先拼接两个链表,在使用冒泡排序即可;示例代码: struct ListNode* mergeTwo…

docker的基本使用以及使用Docker 运行D435i

1.一些基本的指令 1.1 容器 要查看正在运行的容器&#xff1a; sudo docker ps 查看所有的容器&#xff08;包括停止状态的容器&#xff09; sudo docker ps -a 重新命名容器 sudo docker rename <old_name> <new_name> <old_name> 替换为你的容器名称…

查询数据库DQL

DQL 查询基本语法 -- DQL :基本语法; -- 1查询指定的字段 name entrydate 并返回select name , entrydate from tb_emp;-- 2 查询 所有字段 并返回select id, username, password, name, gender, image, job, entrydate, create_time, update_time from tb_emp;-- 2 查询…

C++ 继承和派生 万字长文超详解

本文章内容来源于C课堂上的听课笔记 继承和派生基础 继承是一种概念&#xff0c;它允许一个新创建的类&#xff08;称为子类或派生类&#xff09;获取另一个已经存在的类&#xff08;称为父类或基类&#xff09;的属性和行为。这就好比是子类继承了父类的特征。想象一下&…

【dc-dc】世微 电动车摩托车灯 5-80V 1.2A 一切二降压恒流驱动器AP2915

产品描述 AP2915 是一款可以一路灯串切换两路灯串的降压恒流驱动器,高效率、外围简单、内置功率管&#xff0c;适用于5-80V 输入的高精度降压 LED 恒流驱动芯片。内置功率管输出最大功率可达 12W&#xff0c;最大电流 1.2A。AP2915 一路灯亮切换两路灯亮&#xff0c;其中一路灯…

微积分在神经网络中的本质

calculus 在一个神经网络中我们通常将每一层的输出结果表示为&#xff1a; a [ l ] a^{[l]} a[l] 为了方便记录&#xff0c;将神经网络第一层记为&#xff1a; [ 1 ] [1] [1] 对应的计算记录为为&#xff1a; a [ l ] &#xff1a; 第 l 层 a [ j ] &#xff1a; 第 j 个神经…

在市场发展中寻变革,马上消费金融树行业发展“风向标”

11月11日&#xff0c;2023金融街论坛年会第三届全球金融科技大会“金融科技创新与合规安全”平行论坛在北京召开。会上&#xff0c;马上消费金融副总经理孙磊就数据对金融的赋能作用、数据安全治理等方面展开了深度讨论。 公开信息显示&#xff0c;马上消费金融是一家经中国银保…

如何将文字、图片、视频、链接等内容生成一个二维码?

通过二维彩虹的【H5编辑】功能&#xff0c;就可以将文字、图片、视频、文件、链接等多种格式的内容编辑在一个页面&#xff0c;然后生成一个自定义的二维码——H5编辑二维码。扫描后&#xff0c;即可查看二维码中的详细图文视频等内容了。这个功能大受欢迎&#xff01; 这个H5…

大师学SwiftUI第18章Part1 - 图片选择器和相机

如今&#xff0c;个人设备主要用于处理图片、视频和声音&#xff0c;苹果的设备也不例外。SwiftUI可以通过​​Image​​视图显示图片&#xff0c;但需要其它框架的支持来处理图片、在屏幕上展示视频或是播放声音。本章中我们将展示Apple所提供的这类工具。 图片选择器 Swift…

Three.js相机模拟

有没有想过如何在 3D Web 应用程序中模拟物理相机? 在这篇博文中,我将向你展示如何使用 Three.js和 OpenCV 来完成此操作。 我们将从模拟针孔相机模型开始,然后添加真实的镜头畸变。 具体来说,我们将仔细研究 OpenCV 的两个失真模型,并使用后处理着色器复制它们。 拥有逼…

Rockdb简介

背景 最近在使用flink的过程中&#xff0c;由于要存储的状态很大&#xff0c;所以使用到了rockdb作为flink的后端存储&#xff0c;本文就来简单看下rockdb的架构设计 Rockdb设计 Rockdb采用了LSM的结构&#xff0c;它和hbase很像&#xff0c;不过严格的说&#xff0c;基于LS…