Pytorch深度学习教程_9_nn模块构建神经网络

欢迎来到《深度学习保姆教程》系列的第九篇!在前面的几篇中,我们已经介绍了Python、numpy及pytorch的基本使用,进行了梯度及神经网络的实践并学习了激活函数和激活函数,在上一个教程中我们学习了优化算法。今天,我们将开始使用pytorch构建我们自己的神经网络。

欢迎订阅专栏进行系统学习:

深度学习保姆教程_tRNA做科研的博客-CSDN博客


目录

 1.理解nn模块:

(1)使用 nn.Sequential 创建神经网络

(2)自定义神经网络

2 创建神经网络层

(1)线性层(全连接层)‌

(2)卷积层

(3)池化层

(4)循环层

(5)其他层类型

(6)将层组合成神经网络

3 构建顺序模型

(1)循环神经网络(RNNs)‌

(2)长短期记忆网络(LSTMs)‌

(3)门控循环单元(GRUs)‌

4 自定义神经网络模块

(1)理解自定义模块的需求

创建自定义模块

将自定义模块纳入神经网络

 总结


 

PyTorch 的 nn 模块提供了构建神经网络的高级接口。它封装了层、损失函数和优化算法,使得构建和训练复杂模型变得更加容易。

 1.理解nn模块:

nn 模块提供了一系列类和函数,用于构建各种神经网络组件:

  • 层(Layers)‌:定义对输入数据执行的计算操作(例如,Linear、Conv2d、RNN)。
  • 损失函数(Loss functions)‌:量化预测输出和实际输出之间的差异(例如,CrossEntropyLoss、MSELoss)。
  • 容器(Containers)‌:将层组织成顺序、并行或更复杂的结构(例如,Sequential、ModuleList、ModuleDict)。
  • 初始化(Initialization)‌:初始化模型参数(例如,kaiming_normal_、xavier_uniform_)。

(1)使用 nn.Sequential 创建神经网络

nn.Sequential 容器是一种创建前馈神经网络的简单方法。它定义了一个模块的线性堆栈。

import torch
import torch.nn as nn# 定义一个简单的神经网络
model = nn.Sequential(nn.Linear(input_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, output_size)
)

(2)自定义神经网络

虽然 nn.Sequential 很方便,但您可以通过继承 nn.Module 来创建更复杂的架构。这允许您自定义前向传播逻辑,并对模型结构拥有更多控制权。

import torch
import torch.nn as nn# 定义一个名为MyModel的神经网络模型类,继承自nn.Module
class MyModel(nn.Module):# 构造函数,初始化模型参数def __init__(self, input_size, hidden_size, output_size):# 调用父类nn.Module的构造函数进行初始化super(MyModel, self).__init__()# 定义第一个全连接层(线性层),输入大小为input_size,输出大小为hidden_sizeself.fc1 = nn.Linear(input_size, hidden_size)# 定义第二个全连接层(线性层),输入大小为hidden_size,输出大小为output_sizeself.fc2 = nn.Linear(hidden_size, output_size)# 前向传播函数,定义数据如何通过网络def forward(self, x):# 对第一个全连接层的输出应用ReLU激活函数x = torch.relu(self.fc1(x))# 将经过激活后的输出传递给第二个全连接层x = self.fc2(x)# 返回最终的输出return x# 实例化MyModel模型,传入输入层、隐藏层和输出层的大小
model = MyModel(input_size, hidden_size, output_size)

 

2 创建神经网络层

神经网络层是处理信息的基本组件。它们对输入数据执行计算,将其转换为适合后续层的表示形式。让我们探索一些常见的层类型。

(1)线性层(全连接层)

线性层,也称为全连接层,将一层中的每个神经元连接到下一层中的每个神经元。它们执行矩阵乘法后加上偏置。

import torch.nn as nnlinear_layer = nn.Linear(in_features=10, out_features=20)

(2)卷积层

卷积层对于处理网格状数据(如图像)至关重要。它们应用滤波器来提取特征。

import torch.nn as nnconv_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)

(3)池化层

池化层降低输入的维度,同时保留重要信息。

import torch.nn as nnmax_pool = nn.MaxPool2d(kernel_size=2, stride=2)
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

(4)循环层

循环层处理序列数据。它们维护内部状态以捕捉来自先前步骤的信息。

import torch.nn as nnrnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=1)

(5)其他层类型

  • 归一化层(Normalization layers)‌:归一化输入数据(例如,BatchNorm、LayerNorm)。
  • Dropout 层(Dropout layers)‌:通过随机丢弃神经元来防止过拟合。
  • 嵌入层(Embedding layers)‌:将分类数据转换为密集向量。

(6)将层组合成神经网络

多个层组合在一起可以创建复杂的架构。例如,卷积神经网络通常由卷积层、池化层和全连接层组成。

import torch.nn as nn# 定义一个名为MyCNN的神经网络模型类,继承自nn.Module
class MyCNN(nn.Module):def __init__(self):super(MyCNN, self).__init__()# 定义第一个卷积层self.conv1 = nn.Conv2d(3, 16, 3, padding=1)# 定义最大池化层self.pool = nn.MaxPool2d(2, 2)# 定义第一个全连接层self.fc1 = nn.Linear(16 * 7 * 7, 120)# 定义第二个全连接层self.fc2 = nn.Linear(120, 84)# 定义第三个全连接层self.fc3 = nn.Linear(84, 10)def forward(self, x):# 应用卷积、ReLU激活函数和池化x = self.pool(torch.relu(self.conv1(x)))# 将特征图展平成一维向量x = torch.flatten(x, 1)# 应用第一个全连接层和ReLU激活函数x = torch.relu(self.fc1(x))# 应用第二个全连接层和ReLU激活函数x = torch.relu(self.fc2(x))# 应用第三个全连接层x = self.fc3(x)return x

关键考虑因素

  • 层的深度和宽度(Layer depth and width)‌:尝试不同数量的层和神经元。
  • 超参数调整(Hyperparameter tuning)‌:优化诸如核大小、步幅和填充等参数。
  • 计算效率(Computational efficiency)‌:考虑不同层的计算成本。

 

3 构建顺序模型

顺序模型旨在处理具有时间或空间顺序的数据,例如时间序列数据、文本和音频。它们捕捉序列中元素之间的依赖关系。

理解顺序数据

  • 时间序列数据(Time series data)‌:在特定时间点记录的观测值。
  • 文本数据(Text data)‌:单词或字符的序列。
  • 音频数据(Audio data)‌:音频样本的序列。

(1)循环神经网络(RNNs)

RNNs 是处理顺序数据的基础架构。它们在网络中引入循环,允许信息跨时间步长持续存在。

  • 隐藏状态(Hidden state)‌:维护关于过去输入的信息。
  • 梯度消失问题(Vanishing gradient problem)‌:难以学习长期依赖关系。
import torch.nn as nn# 简单的 RNN
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1)

(2)长短期记忆网络(LSTMs)

LSTMs 通过引入记忆单元和门来解决梯度消失问题。

  • 记忆单元(Memory cell)‌:存储长时间的信息。
  • 门(Gates)‌:控制信息流入和流出记忆单元的流动。
import torch.nn as nnlstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

(3)门控循环单元(GRUs)

GRUs 是 LSTMs 的简化版本,参数更少。

  • 更新门(Update gate)‌:控制保留多少之前的隐藏状态。
  • 重置门(Reset gate)‌:控制忘记多少过去的信息。
import torch.nn as nngru = nn.GRU(input_size=10, hidden_size=20, num_layers=1)

挑战和注意事项

  • 梯度消失/爆炸(Vanishing/exploding gradients)‌:可能阻碍训练,特别是对于长序列。
  • 计算成本(Computational cost)‌:RNNs 训练起来可能计算成本较高。
  • 过拟合(Overfitting)‌:防止模型记住训练数据。

 

4 自定义神经网络模块

虽然 PyTorch 的 nn 模块提供了一套丰富的预建层,但在某些情况下,需要创建针对特定问题域或架构创新的自定义组件。

(1)理解自定义模块的需求

  • 特定问题的操作(Problem-specific operations)‌:某些任务可能需要标准层未涵盖的操作。
  • 架构实验(Architectural experimentation)‌:自定义模块允许灵活地探索新架构。
  • 性能优化(Performance optimization)‌:手工实现有时可能更高效。

创建自定义模块

要创建自定义模块,您需要继承 torch.nn.Module 并实现 forward 方法。该方法定义了对输入张量执行的计算。

import torch
import torch.nn as nn# 定义一个名为MyCustomLayer的自定义神经网络层类,继承自nn.Module
class MyCustomLayer(nn.Module):def __init__(self, in_features, out_features):super().__init__()# 初始化权重参数,使用标准正态分布随机初始化self.weight = nn.Parameter(torch.randn(in_features, out_features))# 初始化偏置参数,使用零初始化self.bias = nn.Parameter(torch.zeros(out_features))def forward(self, x):# 执行线性变换:x @ weight + biasreturn x @ self.weight + self.bias

将自定义模块纳入神经网络

一旦定义了自定义模块,您就可以像使用其他层一样在神经网络架构中使用它。

model = nn.Sequential(MyCustomLayer(10, 20),nn.ReLU(),nn.Linear(20, 5)
)

高级自定义

  • 参数共享(Parameter sharing)‌:创建在不同层之间共享参数的模块。
  • 动态架构(Dynamic architectures)‌:根据输入数据构建具有可变结构的模型。
  • 混合模型(Hybrid models)‌:结合自定义模块与预训练层。

 总结

通过掌握自定义模块的创建,您可以解锁设计高度专业化和创新的神经网络架构的潜力。nn 模块是一个强大的工具,但将其与对神经网络概念的理解相结合,才能构建有效的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/38886.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

EasyRTC嵌入式音视频通信SDK:WebRTC技术下的硬件与软件协同演进,开启通信新时代

在当今数字化时代,智能设备的普及和人们对实时通信需求的不断增长,推动了嵌入式音视频通信技术的快速发。EasyRTC嵌入式音视频通信SDK凭借其独特的技术特点和应用优势,在嵌入式设备和多平台实时通信领域脱颖而出。 1、轻量级设计与高性能 Ea…

Uthana,AI 3D角色动画生成平台

Uthana是什么 Uthana 是专注于3D角色动画生成的AI平台。平台基于简单的文字描述、参考视频或动作库搜索,快速为用户生成逼真的动画,支持适配任何骨骼结构的模型。Uthana 提供风格迁移、API集成和定制模型训练等功能,满足不同用户需求。平台提…

Python:多线程创建的语法及步骤

线程模块:import threading 线程类Thread参数:group(线程组) target:执行的目标的任务名 args:以元组的方式给执行任务进行传参 *args可以传任意多个参数 kwargs以字典方式给执行任务传参 name:线程名 步骤&…

Jupyter Notebook 常用命令(自用)

最近有点忘记了一些常见命令,这里就记录一下,懒得找了。 文章目录 一、文件操作命令1. %cd 工作目录2. %pwd 显示路径3. !ls 列出文件4. !cp 复制文件5. !mv 移动或重命名6. !rm 删除 二、代码调试1. %time 时间2. %timeit 平均时长3. %debug 调试4. %ru…

快速入手-基于Django的Form和ModelForm操作(七)

1、Form组件 2、ModelForm操作 3、给前端表单里在django里添加class相关属性值 4、前端 5、后端form 新增数据处理 6、更新数据处理

【Linux系统】Linux权限讲解!!!超详细!!!

目录 Linux文件类型 区分方法 文件类型 Linux用户 用户创建与删除 用户之间的转换 su指令 普通用户->超级用户(root) 超级用户(root) ->普通用户 普通账户->普通账户 普通用户的权限提高 sudo指令 注: Linux权限 定义 权限操作 1、修改文…

剑指小米特斯拉:秦L EV上市11.98万起

3月23日,比亚迪王朝网推出全新中级纯电轿车秦L EV,价格区间为11.98万-13.98万元,瞬间火爆市场。 依托e平台3.0 Evo技术赋能,秦L EV以“国潮设计、智能座舱、越级空间、高效安全、高阶智驾”五大核心优势,直击年轻用户痛…

嵌入式学习(31)-Lora模块A39C-T400A30D1a

一、概述 A39C-T400A30D1a是一款410~490MHz,1W,具有高稳定性,工业级的无线串口模块。LORA扩频调制,实测传输距离最远可达10K米。该模块具备数据广播、数据监听、定点传输、主从模式、自动中继、定点唤醒等传输方式,支…

使用__attribute__((at(addr))) 固定变量到指定 Flash 地址

文章目录 一、代码示例:将变量固定到 Flash 0x08001000二、__attribute__((at(addr))) 的作用三、__attribute__((at(addr))) 可能导致的问题四、运行时修改 Flash 存储的变量五、在 GCC(STM32CubeIDE)中实现同样功能 在嵌入式开发中&#xf…

vmware虚拟机快照、克隆、迁移区别说明

一、快照 1.1 快照概念 记录了虚拟机在某个特定时间点的状态(软件部署、网络配置、照片备份、游戏存档等) 1.2快照用途 可以在需要时轻松地恢复虚拟机到快照创建时的状态。 备份和恢复:快速备份虚拟机状态的方法可以在数据丢失或损坏时快速恢复虚拟机到先前的状态。测试和…

面试常问系列(一)-神经网络参数初始化

一、背景 说到参数初始化,先提一下大家常见的两个概念梯度消失和梯度爆炸。 (一)、梯度消失:深层网络的“静默杀手” 定义: 在反向传播过程中,梯度值随着网络层数增加呈指数级衰减,最终趋近…

使用CSS3实现炫酷的3D翻转卡片效果

使用CSS3实现炫酷的3D翻转卡片效果 这里写目录标题 使用CSS3实现炫酷的3D翻转卡片效果项目介绍技术要点分析1. 3D空间设置2. 核心CSS属性3. 布局和定位 实现难点和解决方案1. 3D效果的流畅性2. 卡片内容布局3. 响应式设计 性能优化建议浏览器兼容性总结 项目介绍 在这个项目中…

AI Agent开发大全第七课-个人如何申请到靠谱的AI

前言 前面几个课程我们做了一些AI基础知识的铺垫,不要小看基础知识,这些基础知识往往是一些正在从事AI开发的工作者们都没有深入去了解的。 其实这就好比简历上写熟练使用mySql,而实际mySql里那些精妙的参数和设置以及一些底层真的都知道吗? 所以我特别强调基础得打造,…

什么是网络准入?十种常见的网络准入解决方案分享!

在数字化转型的浪潮中,企业网络的边界日益模糊,数据安全与访问控制成为了企业IT管理的核心挑战之一。OneNAC网络准入系统,作为新一代网络安全解决方案的佼佼者,凭借其强大的功能特性和灵活性,在众多网络准入控制&#…

Jetpack Compose 选项卡控件实现

这里写目录标题 介绍主体解释 介绍 实现选项卡控件 主体 import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.…

Java 大视界 -- Java 大数据在智慧文旅旅游目的地营销与品牌传播中的应用(150)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

使用密码连接Redis服务的两种方式

说明:本文介绍连接需要密码的Redis服务的两种方式 方式一 连接时,携带密码,如下: redis-cli -a [密码]如下: 有两个问题: 密码直接放在命令里,可通过 history 找到,不安全&#x…

搭建React简单项目

一、项目构建 目录结构: 安装脚手架 npm install -g create-react-app // or yarn add -g create-react-app 一、项目版本 1、react:"^18.3.1"; 2、react-router-dom:"^6.23.1"; 3、项目创…

知识库已上线

目录 知识库上线了加入知识库注册账号切换租户加入租户找到知识库点击申请等待管理员审核通过后,点击去后台可以开始创作了创建我们的第一个知识库点击详情进入创作页面,创建我们的第一篇知识 发布知识将我们的知识库变更为公开状态发布知识等待管理员审…