PyTorch之完整的神经网络模型训练

简单的示例:

在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码:

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义神经网络的层和参数self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(32 * 14 * 14, 128)self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.softmax(x)return x

在上面的示例中,定义了一个名为MyModel的神经网络模型,继承自nn.Module类。在__init__方法中,我们定义了模型的层和参数。具体来说:

  • 代码定义了一个卷积层,输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1。
  • 定义了一个ReLU激活函数,用于在卷积层之后引入非线性性质。
  • 定义了一个最大池化层,池化核大小为2x2,步长为2。
  • 定义了一个全连接层,输入大小为32x14x14(经过卷积和池化后的特征图大小),输出大小为128。
  • 定义了另一个全连接层,输入大小为128,输出大小为10。
  • 定义了一个softmax函数,用于将模型的输出转换为概率分布。

forward方法中,定义了模型的前向传播过程。具体来说:

  • x = self.conv1(x): 将输入张量传递给卷积层进行卷积操作。
  • x = self.relu(x): 将卷积层的输出通过ReLU激活函数进行非线性变换。
  • x = self.maxpool(x): 将ReLU激活后的特征图进行最大池化操作。
  • x = x.view(x.size(0), -1): 将池化后的特征图展平为一维,以适应全连接层的输入要求。
  • x = self.fc1(x): 将展平后的特征向量传递给第一个全连接层。
  • x = self.relu(x): 将第一个全连接层的输出通过ReLU激活函数进行非线性变换。
  • x = self.fc2(x): 将第一个全连接层的输出传递给第二个全连接层。
  • x = self.softmax(x): 将第二个全连接层的输出通过softmax函数进行归一化,得到每个类别的概率分布。

这个示例展示了一个简单的卷积神经网络模型,适用于处理单通道的图像数据,并输出10个类别的分类结果。可以根据自己的需求和数据特点来定义和修改神经网络模型。

接下来将用于实际的数据集进行训练:

以下是基于CIFAR10数据集的神经网络训练模型:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from nn_mode import *#准备数据集
train_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=False,transform=torchvision.transforms.ToTensor())
#输出数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print(train_data_size)
print(test_data_size)
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=64)
test_loader=DataLoader(dataset=test_data,batch_size=64)
#创建神经网络
sjnet=Sjnet()#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learn_lr=0.01#便于修改
YHQ=torch.optim.SGD(sjnet.parameters(),lr=learn_lr)#设置训练网络的参数
train_step=0#训练次数
test_step=0#测试次数
epoch=10#训练轮数writer=SummaryWriter('wanzheng_logs')for i in range(epoch):print("第{}轮训练".format(i+1))#开始训练for data in train_loader:imgs,targets=dataoutputs=sjnet(imgs)loss=loss_fn(outputs,targets)#优化器YHQ.zero_grad()  # 将神经网络的梯度置零,以准备进行反向传播loss.backward()  # 执行反向传播,计算神经网络中各个参数的梯度YHQ.step()  # 调用优化器的step()方法,根据计算得到的梯度更新神经网络的参数,完成一次参数更新train_step =train_step+1if train_step%100==0:print('训练次数为:{},loss为:{}'.format(train_step,loss))writer.add_scalar('train_loss',loss,train_step)#开始测试total_loss=0with torch.no_grad():#上下文管理器,用于指示在接下来的代码块中不计算梯度。for data in test_loader:imgs,targets=dataoutputs = sjnet(imgs)loss = loss_fn(outputs, targets)#使用损失函数 loss_fn 计算预测输出与目标之间的损失。total_loss=total_loss+loss#将当前样本的损失加到总损失上,用于累积所有样本的损失。print('整体测试集上的loss:{}'.format(total_loss))writer.add_scalar('test_loss', total_loss, test_step)test_step = test_step+1torch.save(sjnet,'sjnet_{}.pth'.format(i))print("模型已保存!")writer.close()

 其神经网络训练以及测试时的损失值使用TensorBoard进行展示,如图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/275029.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云计算OpenStack KVM迁移

动态迁移 static migration 静态迁移 cold migration 冷迁移 offline migration 离线迁移 live migration 动态迁移 hot migration 热迁移 online migration 在线迁移 衡量 整体迁移时间 服务器停机时间 性能影响(迁移后和其它客户机) 特点 负载均衡 解除硬件依赖…

企智汇数字化项目管理平台,助力企业高效项目管理!数字化转型必备!

数字化项目管理平台是一种集成了先进项目信息技术的管理工具,旨在帮助组织更有效地管理项目,实现项目目标的顺利完成。以下是企智汇数字化项目管理平台的一些核心特点和功能: 1. 统一的信息管理:企智汇数字化项目管理平台能够将项…

OpenCASCADE开发指南<四>:OCC 数据类型和句柄

一个软件首先要规定能处理的数据类型, 其次要实现三项最基本的功能——引用管理、内存管理和异常管理。在 OCC 中,这三项功能分别对应基础类中的句柄、内存管理器和异常类。 1 数据类型 在基本概念篇里,已经介绍了 OCC 数据类型的分类&…

网络工程师——2024自学

一、怎样从零开始学习网络工程师 当今社会,人人离不开网络。整个IT互联网行业,最好入门的,网络工程师算是一个了。 什么是网络工程师呢,简单来说,就是互联网从设计、建设到运行和维护,都需要网络工程师来…

工会排队模式:引领创新消费体验的新潮流

在互联网和电子商务的浪潮下,消费者的购物需求与期待正在持续升级。为了迎合这一趋势,工会排队模式应运而生,以其独特的消费体验方式引领市场潮流。 工会排队模式打破了传统电商的桎梏,通过现金返还机制为购物赋予了新的定义。这一…

如何使用 request-promise 在发送请求时使用代理ip?

今天,逛某乎,刷到这个问题,如何在使用 request-promise 时使用代理? 实际不难,我们一起来看看。 如何解决这个问题,我们要知道request-promise 是一个基于Promise的HTTP请求库,可以简化Node.js中发送HTTP…

vue中如何查看组件有哪些函数与变量

在开发的过程中,经常用到他人的框架,特别是开源框架比如element,uniapp等。其中就涉及到框架里对应的组件。而组件里又有哪些内置的函数,我们通常是去查官方文档。然后很多的时候需求的多样性,要改的地方也是不一样的,…

二、vue-cli项目搭建

系列文章: vue实战(商城后台管理系统):http://t.csdnimg.cn/f6Fqa vue.js :http://t.csdnimg.cn/mljxv 目录 系列文章: vue实战(商城后台管理系统):http://t.csdnimg.cn/f6Fqa vue…

制作图片马:二次渲染(upload-labs第17关)

代码分析 $im imagecreatefromjpeg($target_path);在本关的代码中这个imagecreatefromjpeg();函数起到了将上传的图片打乱并重新组合。这就意味着在制作图片马的时候要将木马插入到图片没有被改变的部分。 gif gif图的特点是无损,我们可以对比上传前后图片的内容…

综合实验---Web环境搭建

题目: 服务器IP地址规划:client:12.0.0.12/24,网关服务器:ens36:12.0.0.1/24、ens33:192.168.10.1/24,Web1:192.168.10.10/24,Web2:192.168.10.20/24&#xf…

php集成修改数据库的字段

1.界面效果 2.代码 <?phpecho <form action"" method"post"><label for"table">表名:</label><input type"text" id"table" name"table"><br><div id"fieldsContaine…

Spring 面试题及答案整理,最新面试题

Spring框架中的Bean生命周期是什么&#xff1f; Spring框架中的Bean生命周期包含以下关键步骤&#xff1a; 1、实例化Bean&#xff1a; 首先创建Bean的实例。 2、设置属性值&#xff1a; Spring框架通过反射机制注入属性。 3、调用BeanNameAware的setBeanName()&#xff1a…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的输电线路设备检测系统(深度学习+UI界面+Python代码+训练数据集)

摘要&#xff1a;本篇博客详细介绍了如何运用深度学习构建一个先进的输电线路设备检测系统&#xff0c;并附上了完整的实现代码。该系统利用了最新的YOLOv8算法作为其核心&#xff0c;同时也对之前版本的YOLOv7、YOLOv6、YOLOv5进行了性能比较&#xff0c;包括但不限于mAP&…

电脑工作电压是多少你要看看光驱电源上面标的输入电压范围

要确定电脑的工作电压&#xff0c;必须查看电源上标注的输入电压范围。 国内法规规定民用220V电压范围为10%-15%&#xff0c;也就是说通信220V电压正常范围为187--242V&#xff0c;供电设备一般为180V。 --250V电压范围&#xff0c;即正常情况下电脑电源电压不低于187V即可工作…

土地利用数据分类过程教学/土地利用分类/遥感解译/土地利用获取来源介绍/地理数据获取

本篇主要介绍如何对影像数据进行分类解译&#xff0c;及过程教学&#xff0c;示例数据下载链接&#xff1a;数据下载链接 一、背景介绍 土地是人类赖以生存与发展的重要资源和物质保障&#xff0c;在“人口&#xff0d;资源&#xff0d;环境&#xff0d;发展&#x…

【node】初识node以及fs操作,path操作以及http操作(一)

1、不同浏览器使用不同的javaScript引擎 chrome > v8 Firefox > OdinMonkey&#xff08;奥丁猴&#xff09; safri > JSCore IE浏览器>Chakra(查克拉) 2、node是一个基于chrome v8引擎的javaScript运行环境 浏览器是JavaScript的前端运行环境&#xff0c;…

社交媒体的未来图景:探索Facebook的数字化之旅

随着科技的迅猛发展&#xff0c;数字化社交已经成为了我们日常生活中不可或缺的一部分。在这个数字化时代&#xff0c;社交媒体平台扮演着重要角色&#xff0c;其中Facebook作为社交媒体的先锋&#xff0c;不断探索创新之路&#xff0c;引领着数字化社交的未来发展。本文将深入…

力扣:链表篇章

1、链表 链表是一种通过指针串联在一起的线性结构&#xff0c;每一个节点由两部分组成&#xff0c;一个是数据域一个是指针域&#xff08;存放指向下一个节点的指针&#xff09;&#xff0c;最后一个节点的指针域指向null&#xff08;空指针的意思&#xff09;。 2、链表的类…

图【数据结构】

文章目录 图的基本概念邻接矩阵邻接表图的遍历BFSDFS 图的基本概念 图是由顶点集合及顶点间的关系组成的一种数据结构 顶点和边&#xff1a;图中结点称为顶点 权值:边附带的数据信息 路径 &#xff1a; 简单路径 和 回路&#xff1a; 子图&#xff1a;设图G {V, E}和图G1…

考研C语言复习进阶(1)

目录 1. 数据类型介绍 1.1 类型的基本归类&#xff1a; 2. 整形在内存中的存储 2.1 原码、反码、补码 2.2 大小端介绍 3. 浮点型在内存中的存储 ​编辑 1. 数据类型介绍 前面我们已经学习了基本的内置类型&#xff1a; char //字符数据类型 short //短整型 int /…