深度学习:迁移学习

目录

一、迁移学习

1.什么是迁移学习

2.迁移学习的步骤

1、选择预训练的模型和适当的层

2、冻结预训练模型的参数

3、在新数据集上训练新增加的层

4、微调预训练模型的层

5、评估和测试

二、迁移学习实例

1.导入模型

2.冻结模型参数

3.修改参数

4.创建类,数据增强,导入数据

5.定义训练集和测试集函数

6.将模型传入GPU,并有序调整学习率

7.进行训练和测试


一、迁移学习

1.什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

 

2.迁移学习的步骤

1、选择预训练的模型和适当的层

        通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

 

2、冻结预训练模型的参数

        保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

 

3、在新数据集上训练新增加的层

        在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

 

4、微调预训练模型的层

        在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

 

5、评估和测试

        在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

 

二、迁移学习实例

  • 该实例使用的模型是ResNet-18残差神经网络模型

 

1.导入模型

  • 导入所要用的库,加载ResNet18模型
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np"""将resnet18模型迁移到食物分类项目中"""
resent_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 既调用了resnet18网络,又使用了训练好的模型 在这里下载了模型

 

2.冻结模型参数

  • 将导入的模型参数冻结
for param in resent_model.parameters():param.requires_grad = False  # 设置每个参数的requires_grad属性为False,表示在训练过程中这些参数不需要计算梯度,也就是说它们不会在反向传播中更新。# print(param)
# 模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数
# 使得在反向传播过程中不会计算它们的梯度,以此减少模型的计算量,提高理速度。

 

3.修改参数

  • 因为我们所用的数据分类是20个,原模型分类是1000个,所以需要修改全连接层
  • 获取原模型输入层的特征个数
  • 将原模型的全连接层替换成原输入,输出为20的全连接层
  • 保存需要训练的参数,后面优化器进行优化时就可以只训练该层参数
in_features = resent_model.fc.in_features  # 获取模型原输入的特征个数
resent_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层,输入特征为in_features,输出为20param_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数
for param in resent_model.parameters():if param.requires_grad == True:param_to_update.append(param)

 

4.创建类,数据增强,导入数据

  • 将图片从本地导入,并进行数据增强,最后进行打包
class food_dataset(Dataset):def __init__(self, file_path, transform=None):  # 类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:  # 是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在self.label里samples = [x.strip().split(' ') for x in f.readlines()]  # 去掉首尾空格 再按空格分成两个元素for img_path, label in samples:self.imgs.append(img_path)  # 图像的路径self.labels.append(label)  # 标签,还不是tensor# 初始化:把图片目录加载到selfdef __len__(self):  # 类实例化对象后,可以使用len函数测量对象的个数return len(self.imgs)def __getitem__(self, idx):  # 关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx])  # 读取到图片数据,还不是tensorif self.transform:# 将pil图像数据转换为tensorimage = self.transform(image)  # 图像处理为256x256,转换为tenorlabel = self.labels[idx]  # label还不是tensorlabel = torch.from_numpy(np.array(label, dtype=np.int64))  # label也转换为tensorreturn image, labeldata_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数]),'test':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数])
}train_data = food_dataset(file_path=r'trainda.txt',transform=data_transforms['train'])  # 64张图片为一个包  训练集60000张图片 打包成了938个包
test_data = food_dataset(file_path=r'testda.txt', transform=data_transforms['test'])train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

5.定义训练集和测试集函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w.在训练过程中,w会被修改的batch_size_num = 1for x, y in dataloader:x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入CPU或GPUpred = model.forward(x)  # 向前传播loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 40 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所占用的消耗for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batches  # 能来衡量模型测试的好坏。correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}\n")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:  # 保存正确率最大的那一次的模型best_acc = correct

 

6.将模型传入GPU,并有序调整学习率

from torch import nndevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_avaibale() else 'cpu'
model = resent_model.to(device)  # 为什么不需要加括号,之前是model = CNN().to(device) 因为 resnet_model 是对象不是类"""有序调整学习率"""
loss_fn = nn.CrossEntropyLoss()  # 处理多分类
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 仅训练最后一层的参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 调整学习率

 

7.进行训练和测试

  • 选择训练100轮,每训练一轮,输出测试结果
epchos = 100
acc_s = []
loss_s = []
for t in range(epchos):print(f"Epoch {t + 1}\n--------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最优测试结果为:', best_acc)

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/438603.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GAN|对抗| 生成器更新|判别器更新过程

如上图所示,生成对抗网络存在上述内容: 真实数据集;生成器;生成器损失函数;判别器;判别器损失函数;生成器、判别器更新(生成器和判别器就是小偷和警察的关系,他们共用的…

kubernetes基础操作(pod生命周期)

pod生命周期 一、Pod生命周期 我们一般将pod对象从创建至终的这段时间范围称为pod的生命周期,它主要包含下面的过程: ◎pod创建过程 ◎运行初始化容器(init container)过程 ◎运行主容器(main container&#xff…

记录一次病毒启动脚本

在第一次下载软件时,目录中配了一个使用说明,说是需要通过start.bat 这个文件来启动程序,而这个 start.bat 就是始作俑者: 病毒作者比较狡猾,其中start.bat 用记事本打开是乱码,但是可以通过将这个批处理…

spring揭秘24-springmvc02-5个重要组件

文章目录 【README】【1】HanderMapping-处理器映射容器【1.1】HanderMapping实现类【1.1.1】SimpleUrlHandlerMapping 【2】Controller(二级控制器)【2.1】AbstractController抽象控制器(控制器基类) 【3】ModelAndView(模型与视…

java入门基础(一篇搞懂)

​ 如果您觉得这篇文章对您有帮助的话 欢迎您分享给更多人哦 感谢大家的点赞收藏评论,感谢您的支持!!! 首先给大家推荐比特博哥,java入门安装的JDk和IDEA社区版的安装视频 JDK安装与环境变量的配置 IDEA社区的安装与使…

帝国CMS系统开启https后,无法登陆后台的原因和解决方法

今天本地配置好了帝国CMS7.5,传去服务器后,使用http访问一切正常。但是当开启了https(SSL)后,后台竟然无法登陆进去了。 输入账号密码后,点击登陆,跳转到/e/admin/ecmsadmin.php就变成页面一片…

SpringBoot基础(三):Logback日志

SpringBoot基础系列文章 SpringBoot基础(一):快速入门 SpringBoot基础(二):配置文件详解 SpringBoot基础(三):Logback日志 目录 一、日志依赖二、日志格式1、记录日志2、默认输出格式3、springboot默认日志配置 三、日志级别1、基础设置2、…

golang-基础知识(流程控制)

1 条件判断if和switch 所有的编程语言都有这个if,表示如果满足条件就做某事,不满足就做另一件事,go中的if判断和其它语言的区别主要有以下两点 1. go里面if条件判断不需要括号 2. go的条件判断语句中允许声明一个变量,这个变量…

FPGA-UART串口接收模块的理解

UART串口接收模块 背景 在之前就有写过关于串口模块的文章——《串口RS232的学习》。工作后很多项目都会用到串口模块,又来重新理解一下FPGA串口接收的代码思路。 关于串口相关的参数,以及在文章《串口RS232的学习》中已有详细的描述,这里就…

单调队列与单调栈<2>——单调栈

单调栈的定义 单调递增栈 栈中元素从栈底到栈顶是递增的。 单调递减栈 栈中元素从栈底到栈顶是递减的。 单调栈的核心内容 我们从左到右遍历元素,构造单调栈(从栈顶到栈底递增或减):在 i 从左往右遍历的过程中,我…

手写堆排序

手写堆排序 摘要:本文记录使用go语言实现堆排序 堆的构建 堆性质: 对于每个小堆,父节点与两个子节点比较,父节点比左子节点大,也比右子节点大。 有五个数: 1,2,3,4,5 分别进行入栈。过程如下 (1) 堆为…

(作业)第三期书生·浦语大模型实战营(十一卷王场)--书生入门岛通关第3关Git 基础知识

任务编号 任务名称 任务描述 1 破冰活动 提交一份自我介绍。 2 实践项目 创建并提交一个项目。 破冰活动 提交一份自我介绍。 每位参与者提交一份自我介绍。 提交地址:https://github.com/InternLM/Tutorial 的 camp3 分支~ 安装并设置git 克隆仓库并…

[深度学习][python]yolov11+deepsort+pyqt5实现目标追踪

【算法介绍】 YOLOv11、DeepSORT和PyQt5的组合为实现高效目标追踪提供了一个强大的解决方案。 YOLOv11是YOLO系列的最新版本,它在保持高检测速度的同时,通过改进网络结构、优化损失函数等方式,提高了检测精度,能够同时处理多个尺…

CSS选择器的全面解析与实战应用

CSS选择器的全面解析与实战应用 一、基本选择器1.1 通配符选择器(*)2.标签选择器(div)1.3 类名选择器(.class)4. id选择器(#id) 二、 属性选择器(attr)三、伪…

欧几里得算法--(密码学基础)

根基:gcd(a,b)gcd(b,a mod b) 先举个例子吧,gcd(16,6)gcd(6,4)gcd(4,2)gcd(2,0)2 学习这个定理的时候我想了几个问题. 第一个问题:为什么求出的就一定是他们两个数的公约数? 这个问题很简单我们只需要通过几何来计较即可&#x…

Electron 使⽤ electron-builder 打包应用

electron有几种打包方式,我使用的是electron-builder。虽然下载依赖的时候让我暴躁,使用起来也很繁琐,但是它能进行很多自定义,打包完成后的体积也要小一些。 安装electron-builder: npm install electron-builder -…

python基础语法2

文章目录 1.顺序语句2.条件语句2.1 语法格式 3.缩进与代码块4.空语句 pass5.循环语句5.1 while循环5.2 for循环 5.3 continue与break 1.顺序语句 默认情况下,python的代码都是按照从上到下的顺序依次执行的。 print(hello ) print(world)结果一定是hello world。写…

【AIGC】ChatGPT提示词解析:如何打造个人IP、CSDN爆款技术文案与高效教案设计

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | ChatGPT 文章目录 💯前言💯打造个人IP爆款文案提示词使用方法 💯CSDN爆款技术文案提示词使用方法 💯高效教案设计提示词使用方法 💯小结 💯前言 在这…

zookeeper 服务搭建(集群)

准备3台虚拟机,ip分别是: 192.168.10.75 192.168.10.76 192.168.10.77 准备3个节点 mkdir /usr/local/cluster cd /usr/local/cluster git clone https://gitee.com/starplatinum111/apache-zookeeper-3.5.9-bin.git 重命名文件夹 mv apache-zookeeper…

【学习笔记】手写一个简单的 Spring IOC

目录 一、什么是 Spring IOC? 二、IOC 的作用 1. IOC 怎么知道要创建哪些对象呢? 2. 创建出来的对象放在哪儿? 3. 创建出来的对象如果有属性,如何给属性赋值? 三、实现步骤 1. 创建自定义注解 2. 创建 IOC 容器…