Pytorch网络模型训练

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 简化操作,并且按顺序进行操作self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()# 将优化器中的梯度缓存(如果有的话)清零loss.backward()# 计算损失函数(loss)相对于模型参数的梯度optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:# .item()是将tensor张量变为正常的数字print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# loss.item()是当前步骤的损失值writer.add_scalar("train_loss", loss.item(), total_train_step)# 使用add_scalar可以将一个标量添加到之前的所有标量值中,# 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()# argmax(1)是横向看,argmax(0)是纵向看accuracy = (outputs.argmax(1) == targets).sum()# argmax在找到模型预测的最大概率对应的类别# 预测正确的个数total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/180541.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023全新小程序广告流量主奖励发放系统源码 流量变现系统 带安装教程

2023全新小程序广告流量主奖励发放系统源码 流量变现系统 分享软件,吃瓜视频,或其他资源内容,通过用户付费买会员来变现,用户需要付费,有些人喜欢白嫖,所以会流失一部分用户,所以就写了这个系统…

CSGO饰品价格暴跌的原因分析

CSGO饰品暴跌3个月,盘点6大原因 今天我们来聊一下CSGO饰品市场的情况。大部分装备从3月份开始就一直持续走低,到现在已经是7月份了,还有部分饰品呈阴跌趋势。整个市场沉寂一片,还有些悲观主义者天天在吆喝:市场崩盘了&…

【论文阅读笔记】GLM-130B: AN OPEN BILINGUAL PRE-TRAINEDMODEL

Glm-130b:开放式双语预训练模型 摘要 我们介绍了GLM-130B,一个具有1300亿个参数的双语(英语和汉语)预训练语言模型。这是一个至少与GPT-3(达芬奇)一样好的100b规模模型的开源尝试,并揭示了如何成功地对这种规模的模型进行预训练。在这一过程中&#xff0…

香港金融科技周2023:AIGC重塑金融形态

10月31日,由香港财经事务及库务局与投资推广署主办的“香港金融科技周2023大湾区专场”盛大启幕。中国AI决策领先企业萨摩耶云科技集团创始人、董事长兼 CEO林建明受邀参加圆桌会议,与中国内地、香港以及全球金融科技行业顶尖人才、创新企业、监管机构和…

【C++】特殊类设计

文章目录 一、设计一个类,不能被拷贝二、设计一个类,不能被继承三、设计一个类,只能在栈上创建对象四、设计一个类,只能在堆上创建对象五、设计一个类,只能创建一个对象(单例模式) 在某些特殊的场景下,我们…

“免单优选模式:引爆电商革命,颠覆传统购物体验!“

免单优选模式是一种新型的电商销售模式,其核心理念是通过降低商品售价、设置阶梯式奖励以及利用社交关系链,激发消费者购买欲望,实现销售快速增长。 1、合法合规,不存在多层级奖励。 在免单优选模式中,平台不设置多层…

深度学习_8_对Softmax回归的理解

回归问题,例如之前做房子价格预测的线性回归问题 而softmax回归是一个分类问题,即给定一个图片,从猫狗两种动物类别中选出最可靠的那种答案,这个是两类分类问题,因为狗和猫是两类 上述多个输出可以这样理解,假设一个图…

oracle查询数据库内全部的表名、列明、注释、数据类型、长度、精度等

Oracle查询数据库内全部的表名、列明、注释、数据类型、长度、精度 SELECT a.TABLE_NAME 表名, row_number() over(partition by a.TABLE_NAME order by a.COLUMN_NAME desc) 字段顺序,a.COLUMN_NAME 列名, b.COMMENTS 注释,a.DATA_TYPE 数据类型, a.DATA_LENGTH 长度,DATA_SC…

【后端开发】手写一个简单的线程池

半同步半异步线程池 半同步半异步线程池分为三层: 同步服务层 —— 处理来自上层的任务请求,将它们加入到排队层中等待处理。 同步排队层 —— 实际上是一个“同步队列”,允许多线程添加/取出任务,并保证线程安全。 异步服务层…

R语言657中单色colors颜色索引表---全平台可用

R语言657中单色colors颜色索引表—全平台可用

NLP之LSTM与BiLSTM

文章目录 代码展示代码解读双向LSTM介绍(BiLSTM) 代码展示 import pandas as pd import tensorflow as tf tf.random.set_seed(1) df pd.read_csv("../data/Clothing Reviews.csv") print(df.info())df[Review Text] df[Review Text].astyp…

【计算机网络实验/wireshark】tcp建立和释放

wireshark开始捕获后,浏览器打开xg.swjtu.edu.cn,网页传输完成后,关闭浏览器,然后停止报文捕获。 若捕获不到dns报文,先运行ipconfig/flushdns命令清空dns缓存 DNS报文 设置了筛选条件:dns 查询报文目的…

17、Flink 之Table API: Table API 支持的操作(1)

Flink 系列文章 1、Flink 部署、概念介绍、source、transformation、sink使用示例、四大基石介绍和示例等系列综合文章链接 13、Flink 的table api与sql的基本概念、通用api介绍及入门示例 14、Flink 的table api与sql之数据类型: 内置数据类型以及它们的属性 15、Flink 的ta…

代码随想录训练营第60天 | 503.下一个更大元素II ● 42. 接雨水● 84.柱状图中的最大矩形

503.下一个更大元素II 题目链接:https://leetcode.com/problems/next-greater-element-ii/ 解法: 由于是循环数组,可以直接把两个数组拼接在一起,然后使用单调栈求下一个最大值。 写法上,可以巧妙一些&#xff0c…

【马蹄集】—— 百度之星 2023

百度之星 2023 目录 BD202301 公园⭐BD202302 蛋糕划分⭐⭐⭐BD202303 第五维度⭐⭐ BD202301 公园⭐ 难度:钻石    时间限制:1秒    占用内存:64M 题目描述 今天是六一节,小度去公园玩,公园一共 N N N 个景点&am…

快速灵敏的 Flink1

一、flink单机安装 1、解压 tar -zxvf ./flink-1.13.2-bin-scala_2.12.tgz -C /opt/soft/ 2、改名字 mv ./flink-1.13.2/ ./flink1132 3、profile配置 #FLINK export FLINK_HOME/opt/soft/flink1132 export PATH$FLINK_HOME/bin:$PATH 4、查看版本 flink --version 5、…

轻量封装WebGPU渲染系统示例<14>- 多线程模型载入(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/main/src/voxgpu/sample/ModelLoadTest.ts 此示例渲染系统实现的特性: 1. 用户态与系统态隔离。 细节请见:引擎系统设计思路 - 用户态与系统态隔离-CSDN博客 2. 高频调用与低频调用隔离。 …

C语言--判断一个年份是否是闰年(详解)

一.闰年的定义 闰年是指在公历(格里高利历)中,年份可以被4整除但不能被100整除的年份,或者可以被400整除的年份。简单来说,闰年是一个比平年多出一天的年份,即2月有29天。闰年的目的是校准公历与地球公转周…

CH10_简化条件逻辑

分解条件表达式(Decompose Conditional) if (!aDate.isBefore(plan.summerStart) && !aDate.isAfter(plan.summerEnd))charge quantity * plan.summerRate; elsecharge quantity * plan.regularRate plan.regularServiceCharge;if (summer())…

【蓝桥杯省赛真题42】Scratch舞台特效 蓝桥杯少儿编程scratch图形化编程 蓝桥杯省赛真题讲解

目录 scratch舞台特效 一、题目要求 编程实现 二、案例分析 1、角色分析