完整模型的训练套路

从心所欲

不逾矩

天大地大

皆可去

一、官方模型的初使用

使用VGG16模型

 VGG模型使用代码示例:

import torchvision.models
from torch import nndataset = torchvision.datasets.CIFAR10('/cifar10', False, transform=torchvision.transforms.ToTensor())vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)# 改造VGG,增加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)# 改造vgg,修改一层
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

说明:

  1. pretrained=True:当设置为True时,模型将加载在大规模图像数据集(如ImageNet)上预训练的权重。这些预训练的权重经过了在大量图像上的训练,可以捕捉到通用的图像特征。通过加载预训练权重,可以将VGG模型初始化为在ImageNet上训练得到的状态,并且这些权重可以作为初始参数用于特定任务的微调或迁移学习。

  2. pretrained=False:当设置为False时,模型将使用随机初始化的权重。这意味着模型的权重没有经过预训练,需要从头开始进行训练。在这种情况下,模型将不会具备捕捉通用图像特征的能力,而是需要根据特定任务的数据进行训练。

pretrained=Truepretrained=False区别在于是否加载预训练的权重。如果你想要在特定任务上使用VGG模型,并且你的任务与图像分类或特征提取相关,那么通常建议使用pretrained=True,以便利用预训练权重的优势。如果你的任务与图像分类或特征提取无关,或者你希望从头开始训练模型以适应特定数据集,那么可以选择pretrained=False

二、模型的保存与加载

模型的保存:

两种保存模式,官方推荐第二种,只保存参数,不保存模型

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1: 既保存模型结构,也保存参数
torch.save(vgg16, 'vgg16_model1.pth')# 保存方式2:把参数保存成字典,不保存结构(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_model2.pth')print("end")

模型的加载:
 

import torch
import torchvision.models# 加载方式1 - 保存方式1
model = torch.load('vgg16_model1.pth')# 加载方式2 - 保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_model2.pth'))

三、完整的模型训练套路

以CIFAR10数据集来一个完整的模型训练。

代码示例:

model.py

from torch import nn# 搭建神经网络
class Lh(nn.Module):def __init__(self):super(Lh, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x

train.py

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import Lh# 准备数据集
train_data = torchvision.datasets.CIFAR10('./cifar10', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10('./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 搭建神经网络 - 10分类
lh = Lh()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(lh.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("train_logs")for i in range(epoch):print("-----第{}轮训练开始了-----".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)total_test_loss += lossaccuracy = (output.argmax(1) - - tragets).sum()total_accuracy += accuracyprint("整体测试机上误差:{}".format(total_test_loss))print("整体测试机上的正确率:{}".format(total_accuracy / test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / total_test_step)total_test_step += 1# torch.save(lh, "lhy_{}.pth".format(i))# print("模型已保存")writer.close()

输出结果:

 在tensorboard打开

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/76315.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Electron 开发,报handshake failed; returned -1, SSL error code 1,错误

代码说明 在preload.js代码中,暴露参数给渲染线程renderer.js访问, renderer.js 报:ERROR:ssl_client_socket_impl.cc(978)] failed; returned -1, SSL error code 1,错误 问题原因 如题所说,跨进程传递消息,这意味…

[Docker实现测试部署CI/CD----自由风格的CI操作[最终架构](5)]

目录 11、自由风格的CI操作(最终)Jenkins容器化实现方案修改 docker.sock 权限修改 Jenkins 启动命令后重启 Jenkins构建镜像推送到Harbor修改 daemon.json 文件Jenkins 删除构建后操作Jenkins 添加 shell 命令重新构建 Jenkins通知目标服务器拉取镜像目…

redis的安装和配置

一、nosql 二、redis的安装和配置 redis的安装: redis常见配置: 配置文件redis.conf

【数据结构】这堆是什么

目录 1.二叉树的顺序结构 2.堆的概念及结构 3.堆的实现 3.1 向上调整算法与向下调整算法 3.2 堆的创建 3.3 建堆的空间复杂度 3.4 堆的插入 3.5 堆的删除 3.6 堆的代码的实现 4.堆的应用 4.1 堆排序 4.2 TOP-K问题 首先,堆是一种数据结构,一种特…

在SAP中使用苹果手机进行条码扫描

适用于iOS的Liquid UI支持使用内置摄像头或第三方设备(如Linea Pro)进行条形码扫描。它使您能够通过单击在任何 SAP 输入字段中填充数据。它支持:一维和二维条码扫描。此外,编辑扫描的数据或在扫描后对操作进行编程,以…

2023牛客暑期多校训练营6-C-idol!!

奇数的双阶乘等于小于等于本身的奇数的乘积,偶数的双阶乘等于小于等于本身的非零偶数的乘积。 思路:考虑末位0的个数,我们能想到的最小两数相乘有零的就是2*5,所以本题我们思路就是去找因子2的个数以及因子5的个数,2的…

kubernetes基于helm部署gitlab

kubernetes基于helm部署gitlab 这篇博文介绍如何在 Kubernetes 中使用helm部署 GitLab。 先决条件 已运行的 Kubernetes 集群负载均衡器,为ingress-nginx控制器提供EXTERNAL-IP,本示例使用metallb默认存储类,为gitlab pods提供持久化存储&…

mysql存储过程定时调度

假设我们要创建一个简单的数据库,其中包含两张表:students 表和 courses 表,以及一个存储过程用于插入学生数据。下面是完整的建表语句、插入语句和存储过程: 1】建表 -- 创建 courses 表 CREATE TABLE courses (course_id INT …

回调函数的简单用例

列举项目中一个简单的回调函数用例 ①用MsgInterface_t定义一个结构体s_Lin_MsgInterface,包含两个回调函数成员: ②确定结构体下的回调函数成员的参数: ③传入实参,确定结构体下的回调函数成员的函数名: ④最终回…

【网络】网络层(IP协议)

目录 一、基本概念 二、协议头格式 三、网段划分 四、特殊的IP地址 五、IP地址的数量限制 六、私有IP地址和公网IP地址 七、路由 一、基本概念 IP协议:提供一种能力, 将数据从A主机送到B主机,(TCP协议:确保IP协议…

Jmeter函数助手(一)随机字符串(RandomString)

一、目标 实现一个请求单次调用,请求体里多个集合中的相同参数(zxqs)值随机从序列{01、02、03、03、04、05、06、07、08}中取 若使用CSV数据文件、用户参数等参数化手段,单次执行请求,请求体里多个集合中的相同参数&a…

Django实现音乐网站 ⑹

使用Python Django框架制作一个音乐网站, 本篇主要是在添加编辑过程中对后台歌手功能优化及表模型名称修改、模型继承内容。 目录 表模型名称修改 模型继承 创建抽象基类 其他模型继承 更新表结构 歌手新增、编辑优化 表字段名称修改 隐藏单曲数和专辑数 姓…

LAXCUS分布式操作系统引领科技潮流,进入百度首页

信息源自某家网络平台,以下原样摘抄贴出。 随着科技的飞速发展,分布式操作系统做为通用基础平台,为大数据、高性能计算、人工智能提供了强大的数据和算力支持,已经成为了当今计算机领域的研究热点。近日,一款名为LAXCU…

Mysql on duplicate key update用法及优缺点

在实际应用中,经常碰到导入数据的功能,当导入的数据不存在时则进行添加,有修改时则进行更新, 在刚碰到的时候,一般思路是将其实现分为两块,分别是判断增加,判断更新,后来发现在mysql…

【ASP.NET MVC】使用动软(二)(10)

一、添加动软生成工程 按前文添加动态到工程 双击动软 完成新建数据库服务器后 ,需要关闭重新打开 选择简单三层,注意保存位置 注意切换数据库: 生成后拷贝五个文件夹到工程目录 注意目录结构: 添加四个项目到原来的工程&…

【已解决】安装win7系统,“Windows安装程序无法将Windows配置在此计算机的硬件上运行”

问题: 安装windows7时报错:“Windows安装程序无法将Windows配置在此计算机的硬件上运行” 解决办法: 方法一 shiftF10 调出命令提示行,输入cd oobe 然后再输入msoobe 回车,就可以继续进行下一步了。 如果方法一不…

docker删除容器(步骤详解)

要在Docker中删除容器,需要使用命令docker rm。 下面是详细步骤: 1. 首先,使用docker ps命令查看当前正在运行的容器。这个命令会列出所有正在运行的容器的ID、名称、状态等信息。 如果没有正在运行的容器可以通过docker ps -a 查看当前所…

使用文心一言等智能工具指数级提升嵌入式/物联网(M5Atom/ESP32)和机器人操作系统(ROS1/ROS2)学习研究和开发效率

以M5AtomS3为例,博客撰写效率提升10倍以上: 0. Linux环境Arduino IDE中配置ATOM S3_zhangrelay的博客-CSDN博客 1. M5ATOMS3基础01按键_zhangrelay的博客-CSDN博客 2. M5ATOMS3基础02传感器MPU6886_zhangrelay的博客-CSDN博客 3. M5ATOMS3基础03给RO…

.Net6 Web Core API 配置 Autofac 封装 --- 依赖注入

目录 一、NuGet 包导入 二、Autofac 封装类 三、Autofac 使用 四、案例测试 下列封装 采取程序集注入方法, 单个依赖注入, 也适用, 可<依赖注入>的地方配置 一、NuGet 包导入 Autofac Autofac.Extensions.DependencyInjection Autofac.Extras.DynamicProxy 二、Auto…

Python元编程-装饰器介绍、使用

目录 一、Python元编程装饰器介绍 二、装饰器使用 1. 实现认证和授权功能 2.实现缓存功能 3.实现日志输出功能 三、附录 1. logging.basicConfig介绍 2. 精确到毫秒&#xff0c;打印时间 方法一&#xff1a;使用datetime 方法二&#xff1a;使用time 一、Python元编程…