pytorch实战-图像分类(二)(模型训练及验证)(基于迁移学习(理解+代码))

目录

1.迁移学习概念

2.数据预处理

 3.训练模型(基于迁移学习)

3.1选择网络,这里用resnet

3.2如果用GPU训练,需要加入以下代码

3.3卷积层冻结模块

3.4加载resnet152模

3.5解释initialize_model函数

3.6迁移学习网络搭建

3.7优化器

3.8训练模块(可以理解为主函数)

3.9开始训练

3.10微调

4.测试模型

4.1加载训练好的模型

4.2测试数据预处理

4.3数据展示

4.4提取测试数据集

4.5计算提取数据集的预测结果

4.6展示预测结果

参考文献


1.迁移学习概念

先说一下深度学习常见的问题

        1.数据集不够,通常用数据增强解决。

        2.参数难以确定,训练时间长,这就需要用迁移学习来解决

什么叫迁移学习呢:比方说有一个对100w的自行车数据集,并用VGG模型训练好的网络,而此时你想训练一个1w自行车数据集(虽然对象一样,但采集的数据会不同),也用VGG模型进行训练,你发现,你们数据集的对象一样,选用的网络模型一样,此时在初始化自己模型权重(就是卷积层,池化层和全连接层的参数)时,可以用人家训练好的模型参数(如果不这样就需要随机初始化模型权重),这样做可以节省大量寻找最优参数的时间,又可以保证参数的准确。

总结:迁移学习就是用别人的东西训练自己的东西,但要注意,为了使用别人的模型参数,要保证自己的数据对象、网络结构、输入和输出数据的结构和别人相同。比方说,别人识别狗,你不能识别 猫,别人用VGG你不能用resnet,别人输入和输入图像大小是224×224.你不能是256×256。

进一步理解迁移学习的使用1:看下图最大的红框,表示卷积层,当用别人的模型时,对卷积层的两种处理方式。

        A:作为自己模型权重的初始化参数。

        B:冻结卷积层网络,意思是直接用别人的参数,不再更新。冻结卷积层网络又分几种情况。

                B1:当数据量小时,冻结第二大红框表示的卷积层,剩下卷积层进行更新。因为数据量小时,容易过拟合,直接用别人呢参数最好。

                B2:当数据量中等时冻结最小红框表示的卷积层,剩下的卷积层进行更行。

                B3:当数据量足够大时,不冻结卷积层,用A的方法,只作为自己模型权重的初始化参数。数据量大时,虽然对象一样,但毕竟数据不同,会有一定差异,更新参数是最优选择。

 进一步理解迁移学习的使用2:说完卷积层,在说一下全连接层,必须要注意不管卷积层选A还是B,全连接层都是要更新的,原因在于,别人模型进行图像分类可能是进行1000个分类,而你只进行100或者999个分类,那么全连接层的参数肯定是不同的。

2.数据预处理

上接该文:pytorch实战-图像分类(一)(数据预处理)

 3.训练模型(基于迁移学习)

3.1选择网络,这里用resnet

model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True 

3.2如果用GPU训练,需要加入以下代码

# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available.  Training on CPU ...')
else:print('CUDA is available!  Training on GPU ...')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

3.3卷积层冻结模块

def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False

3.4加载resnet152模

注意:resnet152模型就是别人的模型。

model_ft = models.resnet152()
model_ft

3.5解释initialize_model函数

本小节只是截取pytorch官网的一个例子,用initialize_model说明在pytoch中迁移学习怎么使用,不属于本文代码

具体操作如下

        1.下载别人的模型参数,这里下载restnet152模型

        2.选择需要冻结的卷积层

        3.改变全连接层的输出个数,这里将1000改为102

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):# 选择合适的模型,不同模型的初始化方法稍微有点区别model_ft = Noneinput_size = 0if model_name == "resnet":""" Resnet152"""model_ft = models.resnet152(pretrained=use_pretrained) #下载resnet152模型set_parameter_requires_grad(model_ft, feature_extract) #选择冻结哪部分卷积层num_ftrs = model_ft.fc.in_features #全连接层的输入比方说全连接层是2048×1000,这就是2048.model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102),nn.LogSoftmax(dim=1)) #原resnet152的全连接层输出是1000,自己模型需要的输出是102,进行改动。input_size = 224return model_ft, input_size

3.6迁移学习网络搭建

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)#GPU计算
model_ft = model_ft.to(device)# 模型保存
filename='checkpoint.pth'# 是否训练所有层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:params_to_update = []for name,param in model_ft.named_parameters():if param.requires_grad == True:params_to_update.append(param)print("\t",name)
else:for name,param in model_ft.named_parameters():if param.requires_grad == True:print("\t",name)

3.7优化器

就是用该方法更新模型参数

# 优化器设置
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLLoss()整合
criterion = nn.NLLLoss()

3.8训练模块(可以理解为主函数)

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,filename=filename):since = time.time() #best_acc = 0"""checkpoint = torch.load(filename)best_acc = checkpoint['best_acc']model.load_state_dict(checkpoint['state_dict'])optimizer.load_state_dict(checkpoint['optimizer'])model.class_to_idx = checkpoint['mapping']"""model.to(device)val_acc_history = []train_acc_history = []train_losses = []valid_losses = []LRs = [optimizer.param_groups[0]['lr']]best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 训练和验证for phase in ['train', 'valid']:if phase == 'train':model.train()  # 训练else:model.eval()   # 验证running_loss = 0.0running_corrects = 0# 把数据都取个遍for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 清零optimizer.zero_grad()# 只有训练的时候计算和更新梯度with torch.set_grad_enabled(phase == 'train'):if is_inception and phase == 'train':outputs, aux_outputs = model(inputs)loss1 = criterion(outputs, labels)loss2 = criterion(aux_outputs, labels)loss = loss1 + 0.4*loss2else:#resnet执行的是这里outputs = model(inputs)loss = criterion(outputs, labels)_, preds = torch.max(outputs, 1)# 训练阶段更新权重if phase == 'train':loss.backward()optimizer.step()# 计算损失running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)time_elapsed = time.time() - sinceprint('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 得到最好那次的模型if phase == 'valid' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())state = {'state_dict': model.state_dict(),'best_acc': best_acc,'optimizer' : optimizer.state_dict(),}torch.save(state, filename)if phase == 'valid':val_acc_history.append(epoch_acc)valid_losses.append(epoch_loss)scheduler.step(epoch_loss)if phase == 'train':train_acc_history.append(epoch_acc)train_losses.append(epoch_loss)print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))LRs.append(optimizer.param_groups[0]['lr'])print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 训练完后用最好的一次当做模型最终的结果model.load_state_dict(best_model_wts)return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

3.9开始训练

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, is_inception=(model_name=="inception"))

3.10微调

在2.9中得到的模型,是冻结了卷积层,只训练了全连接层,所以此时希望在此基础上再对卷积层进行训练。

for param in model_ft.parameters():param.requires_grad = True# 再继续训练所有的参数,学习率调小一点
optimizer = optim.Adam(params_to_update, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)# 损失函数
criterion = nn.NLLLoss()# Load the checkpoint,加载自己的模型checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
#model_ft.class_to_idx = checkpoint['mapping']model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10, is_inception=(model_name=="inception"))

4.测试模型

4.1加载训练好的模型

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)# GPU模式
model_ft = model_ft.to(device)# 保存文件的名字
filename='seriouscheckpoint.pth'# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

4.2测试数据预处理

        1.测试数据处理方法需要跟训练时一直才可以

        2.crop操作的目的是保证输入的大小是一致的

        3.标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化

        4.PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换

def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width-224)/2bottom_margin = (img.height-224)/2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin,   top_margin))# 相同的预处理方法img = np.array(img)/255mean = np.array([0.485, 0.456, 0.406]) #provided meanstd = np.array([0.229, 0.224, 0.225]) #provided stdimg = (img - mean)/std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return img

4.3数据展示

def imshow(image, ax=None, title=None):"""展示数据"""if ax is None:fig, ax = plt.subplots()# 颜色通道还原image = np.array(image).transpose((1, 2, 0))# 预处理还原mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = std * image + meanimage = np.clip(image, 0, 1)ax.imshow(image)ax.set_title(title)return ax

4.4提取测试数据集

# 得到一个batch的测试数据
dataiter = iter(dataloaders['valid'])
images, labels = dataiter.next()model_ft.eval()if train_on_gpu:output = model_ft(images.cuda())
else:output = model_ft(images)

4.5计算提取数据集的预测结果

_, preds_tensor = torch.max(output, 1)preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds

4.6展示预测结果

fig=plt.figure(figsize=(20, 20))
columns =4
rows = 2for idx in range (columns*rows):ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])plt.imshow(im_convert(images[idx]))ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

参考文献

1.6-训练结果与模型保存_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75856.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring-1-透彻理解Spring XML的Bean创建--IOC

学习目标 上一篇文章我们介绍了什么是Spring,以及Spring的一些核心概念,并且快速快发一个Spring项目,实现IOC和DI,今天具体来讲解IOC 能够说出IOC的基础配置和Bean作用域 了解Bean的生命周期 能够说出Bean的实例化方式 一、Bean的基础配置 …

分页Demo

目录 一、分页对象封装 分页数据对象 分页查询实体类 实体类用到的utils ServiceException StringUtils SqlUtil BaseMapperPlus,> BeanCopyUtils 二、示例 controller service dao 一、分页对象封装 分页数据对象 import cn.hutool.http.HttpStatus; import com.…

适配器模式(Adapter)

适配器模式用于将一个接口转换成用户希望的另一个接口,适配器模式使接口不兼容的那些类可以一起工作,其别名为包装器(Wrapper)。适配器模式既可以作为类结构型模式,也可以作为对象结构型模式。 Adapter is a structural design pattern that…

Android Studio 关于BottomNavigationView 无法预览视图我的解决办法

一、前言:最近在尝试一步一步开发一个自己的软件,刚开始遇到的问题就是当我们引用 com.google.android.material.bottomnavigation.BottomNavigationView出现了无法预览视图的现象,我也在网上查了很多中解决方法,最后在执行了如下…

三个主流数据库(Oracle、MySQL和SQL Server)的“单表造数

oracle 1.创建表 CREATE TABLE "YZH2_ORACLE" ("VARCHAR2_COLUMN" VARCHAR2(20) NOT NULL ENABLE,"NUMBER_COLUMN" NUMBER,"DATE_COLUMN" DATE,"CLOB_COLUMN" CLOB,"BLOB_COLUMN" BLOB,"BINARY_DOUBLE_COLU…

数据结构 10-排序4 统计工龄 桶排序/计数排序(C语言)

给定公司名员工的工龄,要求按工龄增序输出每个工龄段有多少员工。 输入格式: 输入首先给出正整数(≤),即员工总人数;随后给出个整数,即每个员工的工龄,范围在[0, 50]。 输出格式: 按工龄的递…

cad中的曲线区域是如何绘制的

cad中的曲线区域是如何绘制的 最近需要把cad中的设备锁在区域绘画出来,不同设备放在不同区域 组合工具命令PLPE 步骤: 1.先用pl绘制,把设备都是绘制在pl的曲线范围内 2.用pe命令,选择pl的区域进行曲线(s&#xff…

“单片机定时器:灵活计时与创新功能的关键“

学会定时器的使用对单片机来说非常重要,因为它可以帮助实现各种时序电路。时序电路在工业和家用电器的控制中有广泛的应用。 举个例子,我们可以利用单片机实现一个具有按钮控制的楼道灯开关。当按钮按下一次后,灯会亮起并持续3分钟&#xff…

shell命令

#!/bin/bash read -p "请输入一个文件名:" fileName posexpr index $fileName \. typeexpr substr $fileName $((pos1)) 2if [ $type sh ] thenif [ -x $fileName ]thenbash $fileNameelsechmod ax $fileNamefi firead -p "请输入第一个文件名&…

LLVM笔记1

参考:https://www.bilibili.com/video/BV1D84y1y73v/?share_sourcecopy_web&vd_sourcefc187607fc6ec6bbd2c74a3d0d7484cf 文章目录 零、入门名词解释1. Compiler & Interpreter2. AOT静态编译和JIT动态解释的编译方式3. Pass4. Intermediate Representatio…

node.js的优点

提示:node.js的优点 文章目录 一、什么是node.js二、node.js的特性 一、什么是node.js 提示:什么是node.js? Node.js发布于2009年5月,由Ryan Dahl开发,是一个基于ChromeV8引擎的JavaScript运行环境,使用了一个事件驱…

python可以做哪些小工具,python可以做什么小游戏

大家好,小编来为大家解答以下问题,python可以做什么好玩的,python可以做什么小游戏,今天让我们一起来看看吧! 最近有几个友友问我说有没有比较好玩的Python小项目来练手,于是我找了几个比较有意思的给他们&…

数学建模-爬虫系统学习

尚硅谷Python爬虫教程小白零基础速通(含python基础爬虫案例) 内容包括:Python基础、Urllib、解析(xpath、jsonpath、beautiful)、requests、selenium、Scrapy框架 python基础 进阶(字符串 列表 元组 字典…

IntelliJ IDEA 2023.2社区版插件汇总

参考插件帝:https://gitee.com/zhengqingya/java-developer-document 突发小技巧:使用插件时要注意插件的版本兼容性,并根据自己的实际需求选择合适的插件。同时,不要过度依赖插件,保持简洁和高效的开发环境才是最重要…

linux 安装FTP

检查是否已经安装 $] rpm -qa |grep vsftpd vsftpd-3.0.2-29.el7_9.x86_64出现 vsftpd 信息表示已经安装,无需再次安装 yum安装 $] yum -y install vsftpd此命令需要root执行或有sudo权限的账号执行 /etc/vsftpd 目录 ftpusers # 禁用账号列表 user_list # 账号列…

C++类和对象入门(下)

C类和对象入门 1. Static成员1.1 Static成员的概念2.2 Static成员的特性 2.友元2.1 友元函数2.2 友元函数的特性2.3 友元类 3. 内部类3.1 内部类的概念和特性 4. 匿名对象5. 再次理解类和对象 1. Static成员 1.1 Static成员的概念 声明为static的类成员称为类的静态成员&…

Git基础知识:常见功能和命令行

文章目录 1.Git介绍2.安装配置2.1 查看配置信息 3.文件管理3.1 创建仓库3.2 版本回退3.3 工作流程3.4 撤销修改3.5 删除文件 4.远程仓库4.1 连接远程库4.2 本地上传至远程4.3 从远程库克隆到本地 5.分支管理5.1 创建分支5.2 删除分支5.3 合并分支解决冲突 参考: Git…

Vue前端框架入门

文章目录 Vue快速入门Vue指令生命周期 Vue 经过一小段时间学习 我认为vue就是在原js上进行的一个加强 简化JS中的DOM操作 vue是分两个层的 一个叫做视图层(View),你可以理解为展现出来的前端页面 一个叫数据模型层(Model),包含数据和一些数据的处理方法 MVVM就是实…

Mybatis 实体类属性名和表中字段名不一致怎么处理

一. 前言 最近耀哥有学生出去面试,被问到 “Mybatis实体类的属性名和表中的字段名不一致该怎么处理?”,这其实是一个很经典的面试题,接下来耀哥就为大家详细解析一下这道面试题。 二. 分析 2.1 实体类和字段名不一致所带来的后果…

汽车智能化再掀新热潮!「中央计算架构」进入规模量产周期

中央计算区域控制的新一代整车电子架构,已经成为车企继电动化、智能化(功能上车)之后,新一轮竞争的焦点。 如果说智能化的1.0阶段,是智能驾驶智能座舱的争夺战;那么,即将进入的2.0阶段&#xff…