PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nnclass MyConvNet(nn.Module):def __init__(self):super(MyConvNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.conv1(x)return x# 实例化模型
model = MyConvNet()# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_conv_net(input_tensor, weight, bias=None):output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)return output_tensor# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.relu = nn.ReLU()def forward(self, x):x = self.relu(x)return x# 实例化模型
model = MyModel()# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = F.relu(input_tensor)return output_tensor# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):x = self.linear(x)return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = torch.matmul(input_tensor, weight.t()) + biasreturn output_tensor# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/3285.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字小偷:2025年全面防护指南

在今天的数字时代,金钱已不再局限于传统银行和实体店铺,而是转移到网上银行和电子商务平台上。而随着这一变化,网络犯罪也从现实世界的抢劫演变成了数字世界中的“数字扒窃”。这意味着,几乎每个商业实体,无论大小&…

RV1126+FFMPEG推流项目(6)视频码率及其码率控制方式

视频从采集到编码再到线程获取编码后的数据,已经全部说完。接下来继续来说应该比较重要的,和视频相关的。就是码率。 视频码率及其码率控制方式 一、什么是码率? 视频码率是指在单位时间内传输的视频数据量,通常以 kbps&#x…

python管理工具:conda部署+使用

python管理工具:conda部署使用 一、安装部署 1、 下载 - 官网下载: https://repo.anaconda.com/archive/index.html - wget方式: wget -c https://repo.anaconda.com/archive/Anaconda3-2023.03-1-Linux-x86_64.sh2、 安装 在conda文件的…

深入理解 D3.js 力导向图:原理、调参与应用

目录 前言1. D3.js 力导向图的核心原理1.1 物理模拟与数值积分器1.2 力导向图的物理模型 2. D3.js 力导向图的主要调参项2.1 向心力(d3.forceCenter)2.2 碰撞检测(d3.forceCollide)2.3 弹簧力(d3.forceLink&#xff09…

redis实现限流

令牌桶逻辑 计算逻辑: 代码: import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool;/*** ClassName RedisRateLimiterTokenBucket* Description TODO* Author zhang zhengdong* DATE 2025/1/17 20:22* Version 1.0*/ public class…

Golang Gin系列-2:搭建Gin 框架环境

开始网络开发之旅通常是从选择合适的工具开始的。在这个全面的指南中,我们将引导你完成安装Go编程语言和Gin框架的过程,Gin框架是Go的轻量级和灵活的web框架。从设置Go工作空间到将Gin整合到项目中,本指南是高效而强大的web开发路线图。 安装…

Maven在Win10上的安装教程

诸神缄默不语-个人CSDN博文目录 这个文件可以跟我要,也可以从官网下载: 第一步:解压文件 第二步:设置环境变量 在系统变量处点击新建,输入变量名MAVEN_HOME,变量值为解压路径: 在系统变…

51c大模型~合集106

我自己的原文哦~ https://blog.51cto.com/whaosoft/13115290 #GPT-5、 Opus 3.5为何迟迟不发 新猜想:已诞生,被蒸馏成小模型来卖 「从现在开始,基础模型可能在后台运行,让其他模型能够完成它们自己无法完成的壮举——就像一个老…

SpringBoot+Vue小区智享物业管理系统(高质量源码,可定制,提供文档,免费部署到本地)

作者简介:✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容:🌟Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…

Json转换类型报错问题:java.lang.Integer cannot be cast to java.math.BigDecimal

Json转换类型报错问题:java.lang.Integer cannot be cast to java.math.BigDecimal 小坑规避指南 小坑规避指南 项目中遇到json格式转换成Map,已经定义了Map的key和value的类型,但是在遍历Map取值的时候出现了类型转换的报错问题&#xff08…

在Playwright中使用PO模式

1.新建项目 安装库 npm init -y npm install -D playwright npm install -D playwright/test npm install typescript ts-node types/node npx playwright install 项目目录 2.编写代码 package.json {"name": "pom_playwright","version": …

Web渗透测试之伪协议与SSRF服务器请求伪装结合? 能产生更多的效果

目录 ssrf漏洞利用{危害} 伪协议介绍 Gopher协议简介 SSRF攻击中常用协议 file协议 file协议数据格式: dict协议 dict协议数据格式: Gopher攻击Redis redis的协议数据流格式: 简单示例: 使用gopher协议写入定时任务 …

【论文阅读笔记】人工智能胃镜在盲区检测和自主采图中的应用

作者:李夏/吴练练/于红刚 小结 盲区检测的意思,实际上在算法的需求定义上,就是部位识别。   胃肠镜检查中,按照不同的规范,有不同应该观察到的地方。当医生知道哪些部位比较容易出病灶的情况下,就容易忽…

python之二维几何学习笔记

一、概要 资料来源《机械工程师Python编程:入门、实战与进阶》安琪儿索拉奥尔巴塞塔 2024年6月 点和向量:向量的缩放、范数、点乘、叉乘、旋转、平行、垂直、夹角直线和线段:线段中点、离线段最近的点、线段的交点、直线交点、线段的垂直平…

AI编程工具使用技巧——通义灵码

活动介绍通义灵码1. 理解通义灵码的基本概念示例代码生成 2. 使用明确的描述示例代码生成 3. 巧妙使用注释示例代码生成 4. 注意迭代与反馈原始代码反馈后生成优化代码 5. 结合生成的代码进行调试示例测试代码 其他功能定期优化生成的代码合作与分享结合其他工具 总结 活动介绍…

国产编辑器EverEdit - 复制为RTF

1 复制为RTF 1.1 应用背景 在写产品手册或者其他文档时,可能会用到要将产品代码以样例的形式放到文档中,一般的文本编辑器拷贝粘贴到Word中也就是普通文本,没有语法着色,这样感观上不是太好,为了让读者的感观更好一点…

Python毕业设计选题:基于python的酒店推荐系统_django+hadoop

开发语言:Python框架:djangoPython版本:python3.7.7数据库:mysql 5.7数据库工具:Navicat11开发软件:PyCharm 系统展示 管理员登录 管理员功能界面 用户管理 酒店客房管理 客房类型管理 客房预定管理 用户…

YoloV10改进策略:Neck层改进|EFC,北理提出的适用小目标的特征融合模块|即插即用

论文信息 论文题目:A Lightweight Fusion Strategy With Enhanced Interlayer Feature Correlation for Small Object Detection 论文链接:https://ieeexplore.ieee.org/abstract/document/10671587 官方github:https://github.com/nuliweixiao/EFC 研究贡献 为了解决上…

Re78 读论文:GPT-4 Technical Report

诸神缄默不语-个人CSDN博文目录 诸神缄默不语的论文阅读笔记和分类 论文全名:GPT-4 Technical Report 官方博客:GPT-4 | OpenAI appendix懒得看了。 文章目录 1. 模型训练过程心得2. scaling law3. 实验结果减少风险 1. 模型训练过程心得 模型结构还…

Linux中安装mysql8,很详细

一、查看系统glibc版本号,下载对应版本的MySQL 1、查看glibc版本号办法 方法一:使用ldd命令 在终端中输入ldd --version命令,然后按下回车键。这个命令会显示系统中安装的glibc版本号。例如,如果输出信息是ldd (GNU libc) 2.31&a…