完整网络模型训练(一)

文章目录

    • 一、网络模型的搭建
    • 二、网络模型正确性检验
    • 三、创建网络函数

一、网络模型的搭建

以CIFAR10数据集作为训练例子

准备数据集:

#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

查看数据集的长度:

train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

运行结果:
在这里插入图片描述

利用DataLoader来加载数据集:

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

搭建CIFAR10数据集神经网络:
在这里插入图片描述
卷积层【1】代码解释:
#第一个数字3表示inputs(可以看到图中为3),第二个数字32表示outputs(图中为32)
#第三个数字5为卷积核(图中为5),第四个数字1表示步长(stride)
#第五个数字表示padding,需要计算,计算公式:
在这里插入图片描述

nn.Conv2d(3, 32, 5, 1, 2)

最大池化代码解释:
#数字2表示kernel卷积核

nn.MaxPool2d(2)

读图
卷积层【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述

最大池化【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
卷积层【2】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
以此类推

展平:
在这里插入图片描述
Flatten后它会变成64*4 *4的一个结果

线性输出:
在这里插入图片描述
线性输入是64*4 *4,线性输出是64,故如下代码
nn.LInear(64 *4 *4,64)

继续线性输出
在这里插入图片描述
nn.LInear(64,10)

搭建网络完整代码:

class Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x

二、网络模型正确性检验

if __name__ == '__main__':sen = Sen()input = torch.ones((64, 3, 32, 32))output = sen(input)print(output.shape)

注释:

input = torch.ones((64, 3, 32, 32))

这一行代码的含义是:创建一个大小为 (64, 3, 32, 32) 的全 1 张量,数据类型为 torch.float32。
64:这是批次大小,代表输入有 64 张图片。
3:这是图片的通道数,通常为 RGB 图像的三个通道 (红、绿、蓝)。
32, 32:这是图片的高和宽,表示每张图片的尺寸为 32x32 像素。
torch.ones 函数用于生成一个全 1 的张量,这里的张量形状适合用于输入图像分类或卷积神经网络(CNN)中常见的 CIFAR-10 或类似的 32x32 像素图像数据。

运行结果:
在这里插入图片描述
可以得到成功变成了【64, 10】的结果。

三、创建网络函数

创建网络模型:

sen = Sen()

搭建损失函数:

loss_fn = nn.CrossEntropyLoss()

优化器:

learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

优化器注释:
使用随机梯度下降(SGD)优化器
learning_rate = 1e-2 这里的1e-2代表的是:1 x (10)^(-2) = 1/100 = 0.01

记录训练的次数:

total_train_step = 0

记录测试的次数:

total_test_step = 0

训练的轮数:

epoch= 10

进行循环训练:

for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")

注释:
imgs, targets = data是解包数据,imgs 是输入图像,targets 是目标标签(真实值)
outputs = sen(imgs)将输入图像传入模型 ‘sen’,得到模型的预测输出 outputs
loss = loss_fn(outputs, targets)计算损失值(Loss),loss_fn 是损失函数,它比较outputs的值与targets 是目标标签(真实值)的误差
optimizer.zero_grad()清除优化器中上一次计算的梯度,以免梯度累积
loss.backward()反向传播,计算损失相对于模型参数的梯度
optimizer.step()使用优化器更新模型的参数,以最小化损失
loss.item() 将张量转换为 Python 的数值
loss.item演示:

import torch
a = torch.tensor(5)
print(a)
print(a.item())

运行结果:
在这里插入图片描述
因此可以得到:item的作用是将tensor变成真实数字5

本章节完整代码展示:

import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoaderclass Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)sen = Sen()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")

运行结果:
在这里插入图片描述
可以看到训练的损失函数在一直进行修正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/435980.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端工程规范-2:JS代码规范(Prettier + ESLint)

Prettier 和 ESLint 是两个在现代 JavaScript 开发中广泛使用的工具,它们结合起来可以提供以下作用和优势: 代码格式化和风格统一: Prettier 是一个代码格式化工具,能够自动化地处理代码的缩进、空格、换行等格式问题,…

【C++算法】8.双指针_三数之和

文章目录 题目链接:题目描述:解法C 算法代码:图解 题目链接: 15.三数之和 题目描述: 解法 解法一:排序暴力枚举利用set去重O(n3) 例如nums[-1,0,1,2,-1&…

【C++篇】领略模板编程的进阶之美:参数巧思与编译的智慧

文章目录 C模板进阶编程前言第一章: 非类型模板参数1.1 什么是非类型模板参数?1.1.1 非类型模板参数的定义 1.2 非类型模板参数的注意事项1.3 非类型模板参数的使用场景示例:静态数组的实现 第二章: 模板的特化2.1 什么是模板特化?2.1.1 模板…

基于单片机的催眠电路控制系统

** 文章目录 前言一 概要功能设计设计思路 软件设计效果图 程序文章目录 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主…

Apache DolphinScheduler-1.3.9源码分析(一)

引言 随着大数据的发展,任务调度系统成为了数据处理和管理中至关重要的部分。Apache DolphinScheduler 是一款优秀的开源分布式工作流调度平台,在大数据场景中得到广泛应用。 在本文中,我们将对 Apache DolphinScheduler 1.3.9 版本的源码进…

html+css(如何用css做出京东页面,静态版)

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>京东</title><link rel"stylesheet&q…

AR 眼镜之-蓝牙电话-来电铃声与系统音效

目录 &#x1f4c2; 前言 AR 眼镜系统版本 蓝牙电话 来电铃声 系统音效 1. &#x1f531; Android9 原生的来电铃声&#xff0c;走的哪个通道&#xff1f; 2. &#x1f4a0; Android9 原生的来电铃声&#xff0c;使用什么播放&#xff1f; 2.1 来电铃声创建准备 2.2 来…

联宇集团:如何利用CRM实现客户管理精细化与业务流程高效协同

在全球化的浪潮中&#xff0c;跨境电商正成为国际贸易的新引擎。作为领先的跨境电商物流综合服务商&#xff0c;广东联宇物流有限公司(以下称“联宇集团”)以其卓越的物流服务和前瞻的数字化战略&#xff0c;在全球市场中脱颖而出。本文将基于联宇集团搭建CRM系统的实际案例&am…

PV大题--专题突破

写在前面&#xff1a; PV大题考查使用伪代码控制进程之间的同步互斥关系&#xff0c;它需要我们一定的代码分析能力&#xff0c;算法设计能力&#xff0c;有时候会给你一段伪代码让你补全使用信号量控制的操作&#xff0c;请一定不要相信某些人告诉你只要背一个什么模板&#…

国庆偷偷卷!小众降维!POD-Transformer多变量回归预测(Matlab)

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现POD-Transformer多变量回归预测&#xff0c;本征正交分解数据降维融合Transformer多变量回归预测&#xff0c;使用SVD进行POD分解&#xff08;本征正交分解&#xff09;&#xff1b; 2.运行环境Matlab20…

MobaXterm基本使用 -- 服务器状态、批量操作、显示/切换中文字体、修复zsh按键失灵

监控服务器资源 参考网址&#xff1a;https://www.cnblogs.com/144823836yj/p/12126314.html 显示效果 MobaXterm提供有这项功能&#xff0c;在会话窗口底部&#xff0c;显示服务器资源使用情况 如内存、CPU、网速、磁盘使用等&#xff1a; &#xff08;完整窗口&#xff0…

BEVDet---论文+源码解读

论文链接&#xff1a;https://arxiv.org/pdf/2112.11790.pdf&#xff1b; Github仓库源码&#xff1a;https://github.com/HuangJunJie2017/BEVDet&#xff1b; BEVDet这篇论文主要是提出了一种基于BEV空间下的3D目标检测范式&#xff0c;BEVDet算法模型的整体流程图如下&…

汽车总线之---- LIN总线

Introduction LIN总线的简介&#xff0c;对于传统的这种点对点的连接方式&#xff0c;我们可以看到ECU相关的传感器和执行器是直接连接到ECU的&#xff0c;当传感器和执行器的数量较少时&#xff0c;这样的连接方式是能满足要求的&#xff0c;但是随着汽车电控功能数量的不断增…

基于单片机的指纹打卡系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52RC&#xff0c;采用两个按键替代指纹&#xff0c;一个按键按下&#xff0c;LCD12864显示比对成功&#xff0c;则 采用ULN2003驱动步进电机转动&#xff0c;表示开门&#xff0c;另一个…

RTMP、RTSP直播播放器的低延迟设计探讨

技术背景 没有多少开发者会相信RTMP或RTSP播放器&#xff0c;延迟会做到150-300ms内&#xff0c;除非测试过大牛直播SDK的&#xff0c;以Android平台启动轻量级RTSP服务和推送RTMP&#xff0c;然后Windows分别播放RTSP和RTMP为例&#xff0c;整体延迟如下&#xff1a; 大牛直播…

深度学习后门攻击分析与实现(二)

前言 在本系列的第一部分中&#xff0c;我们已经掌握了深度学习中的后门攻击的特点以及基础的攻击方式&#xff0c;现在我们在第二部分中首先来学习深度学习后门攻击在传统网络空间安全中的应用。然后再来分析与实现一些颇具特点的深度学习后门攻击方式。 深度学习与网络空间…

探索甘肃非遗:Spring Boot网站开发案例

1 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大&#xff0c;随着当前时代的信息化&#xff0c;科学化发展&#xff0c;让社会各行业领域都争相使用新的信息技术&#xff0c;对行业内的各种相关数据进行科学化&#xff0c;规范化管理。这样的大环境让那些止步不前&#…

SpringBoot框架下体育馆管理系统的构建

1引言 1.1课题背景 当今时代是飞速发展的信息时代。在各行各业中离不开信息处理&#xff0c;这正是计算机被广泛应用于信息管理系统的环境。计算机的最大好处在于利用它能够进行信息管理。使用计算机进行信息控制&#xff0c;不仅提高了工作效率&#xff0c;而且大大的提高了其…

工具介绍---效率高+实用

Visual Studio Code (VS Code) 功能特点&#xff1a; 智能代码提示&#xff1a;内置的智能代码提示功能可以自动完成函数、变量等的输入&#xff0c;提高代码编写速度。插件丰富&#xff1a;支持成千上万的扩展插件&#xff0c;例如代码片段、主题、Linting等&#xff0c;能够…

通信工程学习:什么是CSMA/CD载波监听多路访问/冲突检测

CSMA/CD&#xff1a;载波监听多路访问/冲突检测 CSMA/CD&#xff08;Carrier Sense Multiple Access/Collision Detect&#xff09;&#xff0c;即载波监听多路访问/冲突检测&#xff0c;是一种用于数据通信的介质访问控制协议&#xff0c;广泛应用于局域网&#xff08;特别是以…