【深度学习】Pytorch 系列教程(十四):PyTorch数据结构:6、模块(Module):前向传播

        

目录

一、前言

二、实验环境

三、PyTorch数据结构

0、分类

1、张量(Tensor)

2、张量操作(Tensor Operations)

3、变量(Variable)

4、数据集(Dataset)

5、数据加载器(DataLoader)

6、模块(Module)


一、前言

ChatGPT:

        PyTorch是一个开源的机器学习框架,广泛应用于深度学习领域。它提供了丰富的工具和库,用于构建和训练各种类型的神经网络模型。下面是PyTorch的一些详细介绍:

  1. 动态计算图:PyTorch使用动态计算图的方式进行计算,这意味着在运行时可以动态地定义、修改和调整计算图,使得模型的构建和调试更加灵活和直观。

  2. 强大的GPU加速支持:PyTorch充分利用GPU进行计算,可以大幅提升训练和推理的速度。它提供了针对GPU的优化操作和内存管理,使得在GPU上运行模型更加高效。

  3. 自动求导:PyTorch内置了自动求导的功能,可以根据定义的计算图自动计算梯度。这简化了反向传播算法的实现,使得训练神经网络模型更加便捷。

  4. 大量的预训练模型和模型库:PyTorch生态系统中有许多预训练的模型和模型库可供使用,如TorchVision、TorchText和TorchAudio等,可以方便地加载和使用这些模型,加快模型开发的速度。

  5. 高级抽象接口:PyTorch提供了高级抽象接口,如nn.Modulenn.functional,用于快速构建神经网络模型。这些接口封装了常用的神经网络层和函数,简化了模型的定义和训练过程。

  6. 支持分布式训练:PyTorch支持在多个GPU和多台机器上进行分布式训练,可以加速训练过程,处理大规模的数据和模型。

        总体而言,PyTorch提供了一个灵活而强大的平台,使得深度学习的研究和开发更加便捷和高效。它的简洁的API和丰富的功能使得用户可以快速实现复杂的神经网络模型,并在各种任务中取得优秀的性能。

二、实验环境

        本系列实验使用如下环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib

关于配置环境问题,可参考前文的惨痛经历:

三、PyTorch数据结构

0、分类

  • Tensor(张量):Tensor是PyTorch中最基本的数据结构,类似于多维数组。它可以表示标量、向量、矩阵或任意维度的数组。
  • Tensor的操作:PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。
  • Variable(变量):Variable是对Tensor的封装,用于自动求导。在PyTorch中,Variable会自动跟踪和记录对其进行的操作,从而构建计算图并支持自动求导。在PyTorch 0.4.0及以后的版本中,Variable被废弃,可以直接使用Tensor来进行自动求导。
  • Dataset(数据集):Dataset是一个抽象类,用于表示数据集。通过继承Dataset类,可以自定义数据集,并实现数据加载、预处理和获取样本等功能。PyTorch还提供了一些内置的数据集类,如MNIST、CIFAR-10等,用于方便地加载常用的数据集。
  • DataLoader(数据加载器):DataLoader用于将Dataset中的数据按批次加载,并提供多线程和多进程的数据预读功能。它可以高效地加载大规模的数据集,并支持数据的随机打乱、并行加载和数据增强等操作。
  • Module(模块):Module是PyTorch中用于构建模型的基类。通过继承Module类,可以定义自己的模型,并实现前向传播和反向传播等方法。Module提供了参数管理、模型保存和加载等功能,方便模型的训练和部署。

1、张量(Tensor

        

PyTorch数据结构:1、Tensor(张量):维度(Dimensions)、数据类型(Data Types)_QomolangmaH的博客-CSDN博客https://blog.csdn.net/m0_63834988/article/details/132909219​编辑https://blog.csdn.net/m0_63834988/article/details/132909219icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132909219

2、张量操作(Tensor Operations)

3、变量(Variable)

PyTorch数据结构:3、变量(Variable)介绍_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132923358?spm=1001.2014.3001.5501

4、数据集(Dataset)

PyTorch数据结构:4、数据集(Dataset)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132923697?spm=1001.2014.3001.5501

5、数据加载器(DataLoader)

PyTorch数据结构:5、数据加载器(DataLoader)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132924381?spm=1001.2014.3001.5502

6、模块(Module)

        在PyTorch中,构建模型的一般步骤是创建一个自定义的模型类,并将其作为Module的子类。在模型类的初始化方法中(通常为__init__),可以定义模型的层结构和参数。通过重写前向传播方法(forward),可以定义数据在模型中的流动方式。

        Module类提供了许多有用的功能:

  •  参数管理:Module类自动跟踪和管理模型中的可学习参数。这些参数可以通过parameters()方法获得,方便进行优化和参数更新操作。

  • 模型保存和加载:可以使用torch.save()方法将整个模型保存到文件中,以便在以后重新加载和使用。加载模型时,可以使用torch.load()方法加载保存的模型参数。

  • 自动求导:Module类内部的操作默认支持自动求导。在前向传播过程中,PyTorch会自动构建计算图,并记录每个操作的梯度计算方式。这样,在反向传播过程中,可以自动计算和更新模型的参数梯度。 

        除了上述功能,Module类还提供了一些其他方法和属性,例如train()eval()方法用于切换模型的训练和评估模式,to()方法用于将模型移动到指定的设备(如CPU或GPU)等。

import torch
import torch.nn as nn# 自定义模型类
class CustomModel(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(CustomModel, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 创建一个自定义模型的实例
input_size = 10
hidden_size = 20
num_classes = 2model = CustomModel(input_size, hidden_size, num_classes)# 使用模型进行前向传播
input_data = torch.randn(5, input_size)
output = model(input_data)print(output)# 保存模型
torch.save(model.state_dict(), 'model.pth')# 加载模型
loaded_model = CustomModel(input_size, hidden_size, num_classes)
loaded_model.load_state_dict(torch.load('model.pth'))
  • 定义了一个自定义的模型类CustomModel,继承自nn.Module
  • 在模型的初始化方法中定义了模型的层结构,并在前向传播方法中定义了数据的流动方式。
  • 通过创建模型实例model,我们可以使用它进行前向传播。
  • 保存模型参数到文件中,并在之后加载参数来恢复模型。

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/134300.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CDH集群部署

文章目录 1. 资源准备2. 部署 Mariadb 数据库3. 安装CM服务4. 安装数据节点5. 登录CM系统 1. 资源准备 准备好CDH安装包资源,官方网站下载需要账号,如果没有账号可以去网上到处搜搜。主要涉及到的资源有: cloudera-manager-servercloudera-m…

虹科分享 | 软件供应链攻击如何工作?如何评估软件供应链安全?

说到应用程序和软件,关键词是“更多”。在数字经济需求的推动下,从简化业务运营到创造创新的新收入机会,企业越来越依赖应用程序。云本地应用程序开发更是火上浇油。然而,情况是双向的:这些应用程序通常更复杂&#xf…

SPF9139全力适配ios16与鸿蒙3.0,超实用数据提取、分析、恢复能力UP!

​ 如今,群聊已成为人们必不可少的沟通窗口 家人群,好友群,班级群 粉丝群,交友群,工作群 …… 各类群聊铺天盖地般涌来的同时 也有一些群聊沦为了 赌博、传播淫秽视频、发表不当言论 等违法犯罪行为滋生之地 与…

开源日报 0825 | 简化开发过程,提升Swift应用性能的扩展工具库

OpenZeppelin/openzeppelin-contracts Stars: 22.8k License: MIT OpenZeppelin Contracts 是一个用于安全智能合约开发的库。它建立在社区验证过的代码基础上,具有以下主要功能: 实现了 ERC20 和 ERC721 等标准。灵活的基于角色的权限控制方案。可重…

Jenkins :添加node权限获取凭据、执行命令

拥有Jenkins agent权限的账号可以对node节点进行操作,通过添加不同的node可以让流水线项目在不同的节点上运行,安装Jenkins的主机默认作为master节点。 1.Jenkins 添加node获取明文凭据 通过添加node节点,本地监听ssh认证,选则凭…

iOS系统暗黑模式

系统暗黑模式: 暗黑模式颜色适配: 方式1: Assets配置:在Assets中配置好颜色后,可以通过colorNamed: 放大获取到动态颜色。 方式2:代码配置,通过代码colorWithDynamicProvider: 可以看出来生成…

移动端H5封装一个 ScrollList 横向滚动列表组件,实现向左滑动

效果&#xff1a; 1.封装组件&#xff1a; <template><div class"scroll-list"><divclass"scroll-list-content":style"{ background, color, fontSize: size }"ref"scrollListContent"><div class"scroll…

arcgis实现矢量数据的局部裁剪

目录 环境介绍&#xff1a; 操作任务&#xff1a; 方法一&#xff1a;通过arcgis直接选取要素并保存出来 方法二&#xff1a;通过已知的经纬范围&#xff0c;掩膜获取该范围内的矢量数据 环境介绍&#xff1a; Windows操作系统、arcgis10.8 操作任务&#xff1a; 从整体的…

3ds max文件打包?max插件CG Magic一键打包整起!

3ds max文件如何打包&#xff1f;这个问题&#xff0c;小编听到不少网友的提问&#xff01; 今天CG Magic小编来和大家聊聊&#xff0c;文件更高效的操作&#xff0c;如何打包处理呢&#xff1f; 3DMAX这款软件的受众群体是比较高的&#xff0c;在工作方便的同时&#xff0c;…

uniapp h5 echarts 打包后图表点击失效/及其他失效

文章目录 期望效果实际效果环境引入echarts方式解决方法&#xff1a;注意 原因多说一句在h5打包的时候将 history 改为 hash 不然在浏览器打开后刷新会404 期望效果 实际效果 环境 pc端 window11 hbuilderx版本 3.8.12 echarts版本 5.4.3 引入echarts方式 npm install echar…

二,手机硬件参数介绍和校验算法

系列文章目录 第一章 安卓aosp源码编译环境搭建 第二章 手机硬件参数介绍和校验算法 第三章 修改安卓aosp代码更改硬件参数 第四章 编译定制rom并刷机实现硬改(一) 第五章 编译定制rom并刷机实现硬改(二) 第六章 不root不magisk不xposed lsposed frida原生修改定位 第七章 安卓…

Css实现右上角飘带效果

效果图&#xff1a; 源码&#xff1a; <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style type"text/css">*{margin: 0 auto;padding: 0;}.wrap {/* 设置宽高 */width: 350px;height: …

Vue 安装与创建第一Docker的项目

1. 下载nodejs 并且安装 https://nodejs.org/en 2. 打开命令窗口&#xff0c;验证是否安装成功 C:\Users\Harry>node -v v18.16.0C:\Users\Harry>npm -v 9.5.1 3. 安装Vue CLI C:\Users\Harry>npm install -g vue/cli 经过不算漫长的等待&#xff0c;你的Vue CLI就装…

CSS——grid网格布局的基本使用

网格布局在实现页面自适应&#xff0c;大屏可视化中常常使用&#xff0c;在这篇博客里&#xff0c;记录一下网格布局的基本使用。 参考文档&#xff1a;网格布局_菜鸟教程 文章目录 1. 体会grid的自适应性2. grid-template-arr配置网格行列3. 网格单位fr与repeat()简写属性值4…

SpringMVC中的JSR303与拦截器的使用

一&#xff0c;JSR303的概念 JSR303是Java中的一个标准&#xff0c;用于验证和校验JavaBean对象的属性的合法性。它提供了一组用于定义验证规则的注解&#xff0c;如NotNull、Min、Max等。在Spring MVC中&#xff0c;可以使用JSR303注解对请求参数进行校验。 1.2 为什么要使用J…

腾讯mini项目-【指标监控服务重构】2023-08-11

今日待办 使用watermill框架替代当前的base_runner框架 a. 参考官方提供的sarama kafka Pub/Sub(https://github.com/ThreeDotsLabs/watermill-kafka/)实现kafka-go(https://github.com/segmentio/kafka-go)的Pub/Sub&#xff08;sarama需要cgo&#xff0c;会导致一些额外的镜像…

nginx配置指南

nginx.conf配置 找到Nginx的安装目录下的nginx.conf文件&#xff0c;该文件负责Nginx的基础功能配置。 配置文件概述 Nginx的主配置文件(conf/nginx.conf)按以下结构组织&#xff1a; 配置块功能描述全局块与Nginx运行相关的全局设置events块与网络连接有关的设置http块代理…

windows11安装安卓程序的坑

首先&#xff0c;百度一下&#xff0c;网上大把教程&#xff0c;比如&#xff1a; 【2023最新版】Windows11家庭版&#xff1a;安卓子系统&#xff08;WSA&#xff09;安装及使用教程【全网最详细】_QomolangmaH的博客-CSDN博客 写的就比较详细了&#xff0c;仅供参考。 一些…

C【动态内存管理】

1. 为什么存在动态内存分配 int val 20;//在栈空间上开辟四个字节 char arr[10] {0};//在栈空间上开辟10个字节的连续空间 2. 动态内存函数的介绍 2.1 malloc&#xff1a;stdlib.h void* malloc (size_t size); int* p (int*)malloc(40); #include <stdlib.h> #incl…

9.3.5网络原理(应用层HTTP/HTTPS)

一.HTTP: 1. HTTP是超文本传输协议,除了传输字符串,还可以传输图片,字体,视频,音频. 2. 3.HTTP协议报文格式:a.首行,b.请求头(header),c.空行(相当于一个分隔符,分隔了header和body),d.正文(body). 4. 5.URL:唯一资源描述符(长度不限制). a. b.注意:查询字符串(query stri…