3.2.微调

微调

​ 对于一些样本数量有限的数据集,如果使用较大的模型,可能很快过拟合,较小的模型可能效果不好。这个问题的一个解决方案是收集更多数据,但其实在很多情况下这是很难做到的。

​ 另一种方法就是迁移学习(transfer learning),将源数据集学到地知识迁移到目标数据集,例如,我们只想识别椅子,只有100把椅子,每把椅子的1000张不同角度的图像,尽管ImageNet数据集中大多数图像与椅子无关,但在次数据集上训练的模型可能会提取更通用的图像特征(可以理解为越底层的layer提取的特征越通用),这有助于识别边缘、纹理、形状和对象组合,也可能有效地识别椅子。

在这里插入图片描述

1. 步骤

​ 微调是迁移学习中的常见技巧,步骤如下:

  1. 在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型
  2. 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关;因此不在目标模型中使用该层。
  3. 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

在这里插入图片描述

1.1 目标模型的训练:

​ 是一个正在目标数据集上的正常训练任务,但使用更强的正则化(参数变化不大):

  • 更小的学习率
  • 更少的数据迭代

​ 如果源数据集远复杂于目标数据,通常微调效果更好

1.2 重用分类器权重

​ 有些时候源数据集可能也有目标数据中的部分标号,比如ImageNet里可能有椅子这一标签,那么可以使用预训练好的模型分类器中对应标号中对应的向量来做初始化(就直接copy)

1.3 固定一些层

​ 通常而言,神经网络中低层次的特征更加通用,高层次的特征则更跟数据集相关。

​ 那么可以固定底部一些层的参数,不参与更新,这样也能有更强的正则。

2.热狗识别

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import torch_directmldevice = torch_directml.device()
# @save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
d2l.plt.show()# 使用RGB通道的均值和标准差,以标准化每个通道 ,因为预训练的模型做了这个
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])# 先随机裁剪,并变为224 * 224 的图形,因为预训练模型输入是这个
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])
# 将图像的高度和宽度都缩放到256像素,然后裁剪中央 224 * 224的区域来作为输入
test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])# 下载模型,pretrained参数已被弃用,使用weights来获取与训练模型
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)  # 预训练最后一层为输出层finetune_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).to(device)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2).to(device)
nn.init.xavier_uniform_(finetune_net.fc.weight)def train_batch_ch13(net, X, y, loss, trainer, devices):"""用多GPU进行小批量训练"""if isinstance(X, list):# 微调BERT中所需X = [x.to(devices) for x in X]else:X = X.to(devices)y = y.to(devices)net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum# @save 多GPU的,把参数devices改成device了,本来是个列表
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device):"""用多GPU进行模型训练"""timer, num_batches = d2l.Timer(), len(train_iter)animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],legend=['train loss', 'train acc', 'test acc'])# net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):# 4个维度:储存训练损失,训练准确度,实例数,特点数metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch_ch13(net, features, labels, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3],None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc 'f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f'{str(device)}')# 微调
# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)# devices = d2l.try_all_gpus()device = torch_directml.device()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]# 最后一层使用10倍学习率trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,device)train_fine_tuning(finetune_net, 5e-5)
d2l.plt.show()

loss 0.270, train acc 0.899, test acc 0.948
232.6 examples/sec on privateuseone:0

一次训练效果就很好了,而且后续训练很平滑,没有过拟合。

在这里插入图片描述

​ 如果初始化为随机值:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/386927.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++如何理解多态与虚函数

目录 **前言****1. 何为多态**1.1 **编译时多态**1.1.1 函数重载1.1.2 模板 **1.2 运行时多态****1.2.1 虚函数****1.2.2 为什么要用父类指针去调用子类函数** **2. 注意****2.1 基类的析构函数应写为虚函数****2.2 构造函数不能设为虚函数** **本文参考** 前言 在学习 c 的虚…

打造重庆市数字化教育“新名片”,广阳湾珊瑚中学凭实力“出圈”!

分布于教学楼连廊顶部的智能照明设备,根据不同的时间和场景需求自动调节灯光亮度和开关状态;安装于各个教室内的智能黑板、学校同步时钟、学生互动设备,在极简以太全光网的赋能下,为师生提供丰富的教学体验与学习支持......行走于重庆市广阳湾珊瑚中学,像是与充满科技感的“校园…

病理AI领域的基础模型汇总|顶刊专题汇总·24-07-26

小罗碎碎念 本期文献主题:病理AI领域的最新基础模型 今天的推文是一期生日特辑,定时在下午六点二十一分发表(今天农历六月二十一,哈哈),算是自己给自己的24岁生日礼物,希望24岁这一年&#xff0…

ollama本地部署大语言模型记录

目录 安装Ollama更改模型存放位置 拉取模型GemmaMistralQwen1.5(通义千问)codellama 部署Open webui测试性能知识广度问题1问题2 代码能力总结 最近突然对大语言模型感兴趣 同时在平时的一些线下断网的CTF比赛中,大语言模型也可以作为一个能对话交互的高级知识检索…

SSRF中伪协议学习

SSRF常用的伪协议 file:// 从文件系统中获取文件内容,如file:///etc/passwd dict:// 字典服务协议,访问字典资源,如 dict:///ip:6739/info: ftp:// 可用于网络端口扫描 sftp:// SSH文件传输协议或安全文件传输协议 ldap://轻量级目录访问协议 tftp:// 简单文件传输协议 gopher…

【JavaScript】函数声明和函数表达式的区别

文章目录 一、函数声明1. 定义方式2. 作用域提升(Hoisting)3. 块级作用域 二、函数表达式1. 定义方式2. 作用域提升(Hoisting)3. 自引用 三、其他区别1. 函数名2. 可读性和代码组织3. 使用场景 四、总结函数声明函数表达式 在Java…

【大模型系列】Video-LaVIT(2024.06)

Paper:https://arxiv.org/abs/2402.03161Github:https://video-lavit.github.io/Title:Video-LaVIT: Unified Video-Language Pre-training with Decoupled Visual-Motional TokenizationAuthor:Yang Jin, 北大&#x…

Java面试八股之@Qualifier的作用

Qualifier的作用 Qualifier 是 Spring 框架中的一个非常有用的注解,它主要用于解决在依赖注入过程中出现的歧义问题。当 Spring 容器中有多个相同类型的 Bean 时,Qualifier 可以帮助指明应该使用哪一个具体的 Bean 进行注入。 Qualifier 的作用&#x…

外设购物平台

目 录 一、系统分析 二、系统设计 2.1 系统功能设计 2.2 数据库设计 三、系统实现 3.1 注册功能 3.2 登录功能 3.3 分页查询所有商品信息功能 3.4 分页条件(精确、模糊)查询商品信息功能 3.5 购物车功能 3.6 订单管理功能 四、项…

【Opencv】模糊

消除噪声 用该像素周围的平均值代替该像素值 4个函数 blur():最经典的 import os import cv2 img cv2.imread(os.path.join(.,dog.jpg)) k_size 7 #窗口大小,数字越大,模糊越强 img_blur cv2.blur(img,(k_size,k_size)) #窗口是正方形&#xff…

云计算实训16——关于web,http协议,https协议,apache,nginx的学习与认知

一、web基本概念和常识 1.Web Web 服务是动态的、可交互的、跨平台的和图形化的为⽤户提供的⼀种在互联⽹上浏览信息的服务。 2.web服务器(web server) 也称HTTP服务器(HTTP server),主要有 Nginx、Apache、Tomcat 等。…

C#使用csvhelper实现csv的操作

新建控制台项目 安装csvhelper 33.0.1 写入csv 新建Foo.cs namespace CsvSut02;public class Foo {public int Id { get; set; }public string Name { get; set; } }批量写入 using System.Globalization; using CsvHelper; using CsvHelper.Configuration;namespace Csv…

[数据集][目标检测]金属罐缺陷检测数据集VOC+YOLO格式8095张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):8095 标注数量(xml文件个数):8095 标注数量(txt文件个数):8095 标注…

使用Process Explorer和Dependency Walker排查dll动态库加载失败的问题

目录 1、问题描述 2、如何调试Release版本的代码? 3、使用Process Explorer查看exe主程序加载的dll库列表,发现mediaplay.dll没有加载起来 4、使用Dependency Walker查看rtcmpdll.dll的库依赖关系和接口调用情况,定位问题 4.1、使用Depe…

Javascript面试基础6【每日更新10】

Gulp gulp是前端开发过程中一种基于流的代码构建工具,是自动化项目的构建利器;它不仅能对网站资源进行优化,而且在开发过程中很多重复的任务能够使用正确的工具自动完成 Gulp的核心概念:流 流,简单来说就是建立在面向对象基础上的一种抽象的…

多微信聚合神器:高效沟通,一个界面全搞定!

大家都知道,频繁的来回切换微信,不仅浪费时间,还容易错过重要的信息。 今天,我要向大家推荐一款多微信管理神器——个微管理系统,助你实现统一管理,聚合聊天,让沟通变得更加高效。 1、网页扫码…

基于MindIE实现通义千问Qwen推理加速

一、昇腾开发者平台申请镜像 登录Ascend官网昇腾社区-官网丨昇腾万里 让智能无所不及 二、登录并下载mindie镜像 #登录docker login -u XXX#密码XXX#下载镜像docker pull XXX 三、下载Qwen的镜像 使用wget命令下载Qwen1.5-0.5B-Chat镜像,放在/mnt/Qwen/Qwen1.5-…

【无标题】Git(仓库,分支,分支冲突)

Git 一种分布式版本控制系统,用于跟踪和管理代码的变更 一.Git的主要功能: 二.准备git机器 修改静态ip,主机名 三.git仓库的建立: 1.安装git [rootgit ~]# yum -y install git 2.创建一个…

【策略工厂模式】记录策略工厂模式简单实现

策略工厂模式 1. 需求背景2. 代码实现2.1 定义基类接口2.2 排序策略接口定义2.3 定义抽象类,实现策略接口2.4 具体的排序策略实现类2.5 实现策略工厂类2.6 控制类 3. 启动测试4. 总结 1. 需求背景 现在需要你创建一个策略工厂类,来根据策略实现各种排序…

【JAVA】记录一次前端无能造成的 线上bug

有一个需求是 当方式切换 垫资时 清空 当前所选细单商品 但是前端的奇葩 操作是,只是在页面上清空 细单。 不请求 后台删除 细单 让前端 必须 清空同时 请求后台 删除细单 但是 该前端 技术不行, 嫌麻烦 不做 只好 后台 判断该类型时 进行删除操作…