微调及代码

一、微调:迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集

二、步骤

1、在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

2、创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。

3、向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

4、在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

5、目标数据集比源数据集小得多时,微调有助于提高模型的泛化能力。就相当于在别人训练好的基础上训练

三、网络架构:神经网络一般分为两块:特征抽取(将原始像素变成容易线性分割的特征)和线性分类器

四、训练

1、微调是在目标数据集上的正常训练任务,但使用更强的正则化(有更强的lr和更少的epoch)

2、源数据集远复杂于目标数据,通常微调效果会更好

3、源数据集中可能也有目标数据中的部分标号

4、固定底层训练高层

五、总结

1、迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧。

2、除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。但是,目标模型的输出层需要从头开始训练。

3、通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

六、代码

1、导入数据

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

2、定义和初始化模型

pretrained_net = torchvision.models.resnet18(pretrained=True)finetune_net = torchvision.models.resnet18(pretrained=True)
#最后一层类别为2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

3、微调模型

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/375217.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python基础篇(9):模块

1 模块简介 Python 模块(Module),是一个 Python 文件,以 .py 结尾. 模块能定义函数,类和变量,模块里也能包含可执行的代码. 模块的作用: python中有很多各种不同的模块, 每一个模块都可以帮助我们快速的实现一些功能, 比如实现…

概论(二)随机变量

1.名词解释 1.1 样本空间 一次具体实验中所有可能出现的结果,构成一个样本空间。 1.2 随机变量 把结果抽象成数值,结果和数值的对应关系就形成了随机变量X。例如把抛一次硬币的结果,正面记为1,反面记为0。有变量相对应的就有自…

SpringBoot实战:轻松实现接口数据脱敏

一、接口数据脱敏概述 1.1 接口数据脱敏的定义 接口数据脱敏是Web应用程序中一种保护敏感信息不被泄露的关键措施。在API接口向客户端返回数据时,系统会对包含敏感信息(如个人身份信息、财务数据等)的字段进行特殊处理。这种处理通过应用特…

多个版本JAVA切换(学习笔记)

多个版本JAVA切换 很多时候,我们电脑上会安装多个版本的java版本,java8,java11,java17等等,这时候如果想要切换java的版本,可以按照以下方式进行 1.检查当前版本的JAVA 同时按下 win r 可以调出运行工具…

WMS系统的核心功能

WMS系统(Warehouse Management System)的核心功能主要包括以下几个方面: ———————————————————————— 1、库存管理: 1):跟踪库存数量、位置和状态,确保实时库存可见性。 2):支持批次管理、序列…

文心快码——百度研发编码助手

介绍 刚从中国互联网大会中回来,感受颇深吧。百度的展商亮相了文心快码,展商人员细致的讲解让我们一行了解到该模型的一些优点。首先,先来简单介绍一下文心快码吧。 文心快码(ERNIE Code)是百度公司推出的一个预训练…

【STM32标准库】读写内部FLASH

1.内部FLASH的构成 STM32F407的内部FLASH包含主存储器、系统存储器、OTP区域以及选项字节区域。 一般我们说STM32内部FLASH的时候,都是指这个主存储器区域,它是存储用户应用程序的空间。STM32F407ZGT6型号芯片, 它的主存储区域大小为1MB。其…

ppt翻译免费怎么做?5个方法让你秒懂PPT的内容

当你收到一份来自海外的PPT资料,眼前或许是一片陌生的语言海洋,但别让这成为理解与灵感之间的障碍。 这时,一款优秀的PPT翻译软件就如同你的私人导航员,能迅速将这份知识宝藏转化为你熟悉的语言,让每一个图表、每一段…

Unity引擎制作玻璃的反射和折射效果

Unity引擎制作玻璃球玻璃杯 大家好,我是阿赵。   之前做海面效果的时候,没做反射和折射的效果,因为我觉得过于复杂的效果没有太大的实际作用。这方面的效果,我就做了现在这个例子来补充一下。 在这个demo场景里面,我…

社交媒体数据分析:赋能企业营销策略的利器

一、数据:未来的石油与导航仪 在数字化转型的大潮中,数据已成为推动企业发展的新燃料。它不仅是决策的依据,更是预见未来的水晶球。特别是在社交媒体这片广袤的海洋里,每一条帖子、每一次点赞、评论都蕴藏着消费者的偏好、市场的…

thinkphp8框架源码精讲

前言 很开心你能看到这个笔记,相信你对thinkphp是有一定兴趣的,正好大家都是志同道合的人。 thinkphp是我入门学习的第一个框架,经过这么多年了,还没好好的研究它,今年利用了空闲的时间狠狠的深入源码学习了一把&…

Proteus + Keil单片机仿真教程(五)多位LED数码管的静态显示

Proteus + Keil单片机仿真教程(五)多位LED数码管 上一章节讲解了单个数码管的静态和动态显示,这一章节将对多个数码管的静态显示进行学习,本章节主要难点: 1.锁存器的理解和使用; 2.多个数码管的接线封装方式; 3.Proteus 快速接头的使用。 第一个多位数码管示例 元件…

Qt学生管理系统(付源码)

Qt学生管理系统 一、前言1.1 项目介绍1.2 项目目标 2、需求说明2.1 功能性说明2.2 非功能性说明 三、UX设计3.1 登录界面3.2 学生数据展示3.3 信息插入和更新 三、架构说明3.1 客户端结构如下3.2 数据流程图3.2.1 数据管理3.2.2 管理员登录 四、 设计说明3.1 数据库设计3.2 结构…

嵌入式要卷成下一个Java了吗?

嵌入式系统与Java的关系在技术发展和市场需求的影响下在逐步演变,但尚未达到完全替代的阶段。我收集归类了一份嵌入式学习包,对于新手而言简直不要太棒,里面包括了新手各个时期的学习方向编程教学、问题视频讲解、毕设800套和语言类教学&…

system V共享内存【Linux】

文章目录 原理shmgetftokshmat(share memory attach)shmdt,去关联(share memory delete attach)shmctl ,删除共享内存共享内存与管道 原理 共享内存本质让不同进程看到同一份资源。 申请共享内存: 1、操作系统在物理内存当中申请…

【鸿蒙学习笔记】通过用户首选项实现数据持久化

官方文档:通过用户首选项实现数据持久化 目录标题 使用场景第1步:源码第2步:启动模拟器第3步:启动entry第6步:操作样例2 使用场景 Preferences会将该数据缓存在内存中,当用户读取的时候,能够快…

从2024上半年《人工智能现状报告》看GPU前世今生

前不久,全球领先的低代码平台Retool发布了最新的2024上半年《人工智能现状报告》,这份报告收集了约750名技术人员的意见,包括开发者、数据团队和各行业的领导者。报告通过调研人们对AI产生的情绪变化、AI应用现状、AI使用率等等几个方面总结了…

上海慕尼黑电子展开展,启明智显携物联网前沿方案亮相

随着科技创新的浪潮不断涌来,上海慕尼黑电子展在万众瞩目中盛大开幕。本次展会汇聚了全球顶尖的电子产品与技术解决方案,成为业界瞩目的焦点。启明智显作为物联网彩屏显示领域的佼佼者携产品亮相展会,为参展者带来了RTOS、LINUX全系列方案及A…

HTML 基础

文章目录 HTML 结构认识 HTML 标签HTML 文件基本结构快速生成代码框架 HTML 常见标签注释标签标题标签: h1-h6段落标签: p换行标签: br格式化标签图片标签: img超链接标签: a表格标签列表标签表单标签form 标签input 标签 label 标签select 标签textarea 标签无语义标签: div &…

浏览器书签助手mTab

本文软件由网友 P家单推人 推荐 什么是 mTab ? mTab 是免费无广告的浏览器书签助手,多端同步、美观易用的在线导航和书签工具,可以用 mTab 书签收藏并自定义常用网站的图标样式,帮助您高效管理网页和应用,提升在线体验。 官方提供…