深入理解全连接层:从线性代数到 PyTorch 中的 nn.Linear 和 nn.Parameter

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层,线性层)

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)print("偏置项 b:")
print(linear_layer.bias)# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):def __init__(self, input_dim, output_dim):super(CustomLinearLayer, self).__init__()# 使用 nn.Parameter 手动定义权重和偏置self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):# 手动实现线性变换 y = Wx + breturn torch.matmul(x, self.weight.T) + self.bias# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

这正是我写这篇博客的原因,接下来我们详细解释这个问题。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features,全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , out_features ) (\text{batch\_size}, \text{out\_features}) (batch_size,out_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)# 使用 @ 运算符
result2 = W @ input_vectorprint("使用 torch.matmul 计算的结果:")
print(result1)print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入。

示例代码:

import torch.nn.functional as F# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/423294.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue的缓存组件 | 详解KeepAlive

引言 在Vue开发中,我们经常需要处理大量的组件渲染和销毁操作,这可能会影响应用的性能和用户体验。而Vue的KeepAlive组件提供了一种简便的方式来优化组件的渲染和销毁流程,通过缓存已经渲染的组件来提升应用的性能。 本文将详细介绍Vue的Ke…

即插即用篇 | YOLOv10 引入矩形自校准模块RCM | ECCV 2024

本改进已同步到YOLO-Magic框架! 语义分割是许多应用的重要任务,但要在有限的计算成本下实现先进性能仍然非常具有挑战性。在本文中,我们提出了CGRSeg,一个基于上下文引导的空间特征重建的高效且具有竞争力的分割框架。我们精心设计了一个矩形自校准模块,用于空间特征重建和…

经典RNA-seq分析流程1

RNA-seq分析有很多流程, 一般都是上游linux工具获取表达矩阵数据,然后就可以使用下游R包进行处理了,要么是差异DEG表达gene等分析; 因为下游分析其实R包是明确的,毕竟有很多生信分析教程,但是上游的linux…

无人机之处理器篇

无人机的处理器是无人机系统的核心部件之一,它负责控制无人机的飞行、数据处理、任务执行等多个关键功能。以下是对无人机处理器的详细解析: 一、处理器类型 无人机中使用的处理器主要包括以下几种类型: CPU处理器:CPU是无人机的…

神经网络多层感知器异或问题求解-学习篇

多层感知器可以解决单层感知器无法解决的异或问题 首先给了四个输入样本,输入样本和位置信息如下所示,现在要学习一个模型,在二维空间中把两个样本分开,输入数据是个矩阵,矩阵中有四个样本,样本的维度是三维…

Unity全面取消Runtime费用 安装游戏不再收版费

Unity宣布他们已经废除了争议性的Runtime费用,该费用于2023年9月引入,定于1月1日开始收取。Runtime费用起初是打算根据使用Unity引擎安装游戏的次数收取版权费。2023年9月晚些时候,该公司部分收回了计划,称Runtime费用只适用于订阅…

[数据集][目标检测]车窗状态检测车窗开关检测数据集VOC+YOLO格式299张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):299 标注数量(xml文件个数):299 标注数量(txt文件个数):299 标注类别…

应用程序已被 Java 安全阻止:Java 安全中的添加的例外站点如何对所有用户生效

如题:应用程序已被 Java 安全阻止,如下图所示: 在寻找全局配置的时候花了一个上午的时间,到处搜解决方法,都不可行。最后还是参考官方的文档配置好了。如果你碰到了同样的问题,这篇文章一定可以帮到你。 环…

论文阅读:AutoDIR Automatic All-in-One Image Restoration with Latent Diffusion

论文阅读:AutoDIR: Automatic All-in-One Image Restoration with Latent Diffusion 这是 ECCV 2024 的一篇文章,利用扩散模型实现图像恢复的任务。 Abstract 这篇文章提出了一个创新的 all-in-one 的图像恢复框架,融合了隐扩散技术&#x…

【重学 MySQL】二十八、SQL99语法新特性之自然连接和 using 连接

【重学 MySQL】二十八、SQL99语法新特性之自然连接和 using 连接 自然连接(NATURAL JOIN)USING连接总结 SQL99语法在SQL92的基础上引入了一些新特性,其中自然连接(NATURAL JOIN)和USING连接是较为显著的两个特性。 自…

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》P84

更正卷积与相关微课中互相关运算动画中的索引。 1-D correlation rectwave 禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 禹晶、肖创柏、廖庆敏《数字图像处理》资源二维码

性能测试【Locust】基本使用介绍

一.前言 Locust是一款易于使用的分布式负载测试工具,基于事件驱动,使用轻量级执行单元(如协程)来实现高并发。 二.基本使用 以下是Locust性能测试使用的一个基础Demo示例,该示例有安装Locust、编写测试脚本、启动测…

三方共建 | 网络安全运营中心正式揭牌成立

9月3日,广州迎来了一场网络安全领域的盛事。悦学科技、聚铭网络、微步在线联合打造的7x24小时网络安全运营中心(以下简称“中心”)正式成立,并在现场举行了庄重而热烈的揭牌仪式。众多行业专家、企业代表齐聚一堂,共同…

MPP数据库之SelectDB

SelectDB 是一个高性能、云原生的 MPP(大规模并行处理)数据库,旨在为分析型数据处理场景提供快速、弹性和高效的解决方案。它专为处理大规模结构化和半结构化数据设计,常用于企业级业务分析、实时分析和决策支持。 SelectDB 是在…

实习项目|苍穹外卖|day9

实战作业。 用户端新增功能 1. 查询历史订单 接口设计 返回的是orderorderdetails(那我这里就先查order,再根据order_id查) 分页 pageHelper的使用: //controller相关函数GetMapping("/historyOrders")ApiOperati…

【GBase 8c V5_3.0.0 分布式数据库常用几个SQL】

1.检查应用连接数 以管理员用户 gbase,登录数据库主节点。 接数据库,并执行如下 SQL 语句查看连接数。 SELECT count(*) FROM (SELECT pg_stat_get_backend_idset() AS backendid) AS s;2.查看空闲连接 查看空闲(state 字段为”idle”)且长时间没有更…

AI问答-Vue实例属性/实例方法:$refs、$emit、$attrs、$props、$data...

一、本文简介 在Vue.js中,$ 符号通常用于表示Vue实例或组件上的内置属性和方法,这些被称为“实例属性”或“实例方法”。以下是一些常见的以$开头的Vue实例属性和方法 1.1、实例属性 序号实例属性解释1$dataVue实例的数据对象,用于存储组件…

Linux - 探秘/proc/sys/net/ipv4/ip_local_port_range

文章目录 Pre概述默认值及其意义评估需求如何调整临时修改永久修改测试和验证 修改的潜在影响 Pre Linux - 探秘 Linux 的 /proc/sys/vm 常见核心配置 计划: 简要解释 /proc/sys/net/ipv4/ip_local_port_range 文件的功能和作用。介绍该文件的默认值及其影响。说明…

ChatGPT: A Simulator Who Passed the Turing Test?

文章目录 引言Introduction:Applications:Discussion:Future Outlook:汉语翻译 引言 本文是一篇英语课前pre,简单介绍了ChatGPT的功能,内容一般,希望能帮到你。🙂 Introduction: Standing at the intersection of natural lan…

Failed building wheel for opencv-python-headless

Failed building wheel for opencv-python-headless 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开发者社区主理人…