【PyTorch函数解析】einsum的用法示例

一、前言

einsum 是一个非常强大的函数,用于执行张量(Tensor)运算。它的名称来源于爱因斯坦求和约定(Einstein summation convention),在PyTorch中,einsum 可以方便地进行多维数组的操作和计算。

在Transfomer中,einsum用的非常多,比如使用 einsum 实现自注意力机制中注意力权重的获取,也就是Q和K的内积:

  • Q(Query):形状为 (batch_size, seq_len, d_k)

  • K(Key):形状为 (batch_size, seq_len, d_k)

import torch
import torch.nn.functional as FQ = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)# (batch_size, seq_len, seq_len)
attention_scores = torch.einsum('bqd,bkd->bqk', Q, K) / torch.sqrt(torch.tensor(64.0))
# (batch_size, seq_len, seq_len)   
attention_weights = F.softmax(attention_scores, dim=-1)  

二、常见用法示例

2.1 向量点积

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出 32

这里,'i,i->' 表示对向量 a 和 b 进行点积操作,其中 i 是索引表示,-> 之后为空表示求和。

2.2 矩阵乘法

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出 tensor([[19, 22], [43, 50]])

这里,'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是结果的维度,j 是求和维度。

2.3 批量矩阵乘法

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
result = torch.einsum('bij,bjk->bik', A, B)

这里,'bij,bjk->bik' 表示对批量的矩阵进行乘法运算。

解释:

bij,bjk分别是A和B的3个维度,用字符串的形式指代。

为什么最后得到的是bik呢?这个和线性代数的矩阵运算规则有关系。

矩阵乘法规则:

  • 给定矩阵 A 的形状为 (m,n)

  • 给定矩阵 B 的形状为 (n,p)

  • 矩阵乘法 A×B 的结果矩阵 C 的形状为 (m,p)

在矩阵乘法中,结果矩阵的每个元素 Cik 是通过 A 的第 i 行和 B 的第 k 列的对应元素相乘并求和得到的,即:

C_{ik}=\sum_{j=1}^nA_{ij}\cdot B_{jk}

计算过程:

1. 匹配批次维度 (b)

  • 对于每个批次,独立进行矩阵乘法运算。

2. 求和维度 (j):

  • j 是两个张量中共同的维度,根据线性代数中的矩阵乘法规则,需要对 j 维度进行求和。

3. 保留和产生的维度:

  • i 来自 A,表示保留 A 的第一个维度。

  • k 来自 B,表示保留 B 的第二个维度。

经过上述分析,einsum 的结果保留了 b(批次维度)、i(来自 A 的第一个维度)和 k(来自 B 的第二个维度)。因此,结果张量的形状为 (batch_size, seq_len_i, seq_len_k),也就是 bik。

同样,延伸到4维计算的话。

torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

首先,假设 queries 和 keys 的形状为:

  • queries: (batch_size, seq_len_q, num_heads, head_dim)

  • keys: (batch_size, seq_len_k, num_heads, head_dim)

用具体变量名表示:

  • n: batch_size,批次大小。

  • q: seq_len_q,查询序列的长度。

  • k: seq_len_k,键序列的长度。

  • h: num_heads,多头注意力中的头数。

  • d: head_dim,每个头的维度。

1. 匹配批次维度 (n) 和头部维度 (h):

  • 批次大小和头部数量在两个输入张量中都是相同的,保持不变。

2. 求和维度 (d):

  • d 表示每个头的维度。在 queries 和 keys 中,d 都是最后一个维度,对这个维度进行点积运算后求和。

3. 保留和产生的维度:

  • q 来自 queries,表示查询序列的长度。

  • k 来自 keys,表示键序列的长度。

所以最后是nhqk。

2.4 转置操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出 tensor([[1, 4], [2, 5], [3, 6]])

这里,'ij->ji' 表示将矩阵进行转置操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/363526.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPT-5:AI新纪元的领航者,多维度的审视与准备

一、引言:GPT-5与AI的多维演进 GPT-5作为AI领域的里程碑式突破,不仅仅代表了技术的飞跃,更预示着社会、文化以及经济等多个层面的深刻变革。从技术的角度看,GPT-5代表着AI在自然语言处理领域的最新高度;而从更宽广的视…

Linux双网卡默认路由的metric设置不正确,导致SSH连接失败问题定位

测试环境 VMware虚拟机 RockyLinux 9 x86_64 双网卡:eth0(访问外网): 10.206.216.92/24; eth1(访问内网) 192.168.1.4/24 问题描述 虚拟机重启后,SSH连接失败,提示"Connection time out",重启之前SSH连接还是正常的…

音视频入门基础:H.264专题(8)——H.264官方文档的描述符

音视频入门基础:H.264专题系列文章: 音视频入门基础:H.264专题(1)——H.264官方文档下载 音视频入门基础:H.264专题(2)——使用FFmpeg命令生成H.264裸流文件 音视频入门基础&…

Java代码基础算法练习-删除有序数组中的重复项-2024.05.07

任务描述: 有一批同学需要计算各自的出生年月是否闰年。请使用算法计算出他们的出生年份是否闰年。 解决思路: 如果要一次性输出结果,就是先输入数字n,确定首先循环几次,在每次循环中进行闰年判断操作,每次…

RK3588/算能/Nvidia智能盒子:[AI智慧油站」,以安全为基,赋能精准经营

2021年9月,山东省应急管理厅印发了关于《全省危险化学品安全生产信息化建设与应用工作方案(2021-2022 年)》的通知,要求全省范围内加快推进危险化学品安全生产信息化、智能化建设与应用工作,建设完善全省危险化学品安全…

遥感数据并行运算(satellite remote sensing data parallell processing)

文章内容仅用于自己知识学习和分享,如有侵权,还请联系并删除 :) 之前不太会用,单纯想记录一下,后面或许还会用到 1. 教程 [1] Pleasingly Parallel Programming: link 1.1 处理器,核和线程 …

山东水利职业学院空调集控系统案例,节能减排、降低维护成本

日常在公共办公场所使用空调时,人离开办公室空调依然开着,由于适用空调的不良行为导致能源浪费。良好的学习环境是保持学生好的学习状态的前提条件,让学生在炎热的夏季都能享受到舒适的室内空气环境是很重要的,对空调集中管理&…

ASUS/华硕天选Air 2021 FX516P系列 原厂win10系统

安装后恢复到您开箱的体验界面,带原机所有驱动和软件,包括myasus mcafee office 奥创等。 最适合您电脑的系统,经厂家手调试最佳状态,性能与功耗直接拉满,体验最原汁原味的系统。 原厂系统下载网址:http:…

Spring Clude 是什么?

目录 认识微服务 单体架构 集群和分布式架构 集群和分布式 集群和分布式区别和联系 微服务架构 分布式架构&微服务架构 微服务的优势和带来的挑战 微服务解决方案- Spring Cloud 什么是 Spring Cloud Spring Cloud 版本 Spring Cloud 和 SpringBoot 的关系 Sp…

深度学习 —— 1.单一神经元

深度学习初级课程 1.单一神经元2.深度神经网络3.随机梯度下降法4.过拟合和欠拟合5.剪枝、批量标准化6.二分类 前言 本套课程仍为 kaggle 课程《Intro to Deep Learning》,仍按之前《机器学习》系列课程模式进行。前一系列《Keras入门教程》内容,与本系列…

STM32 IWDG(独立看门狗)

1 IWDG简介 STM32有两个看门狗:一个是独立看门狗(IWDG),另外一个是窗口看门狗。独立看门狗也称宠物狗,窗口看门狗也称警犬。本文主要分析独立看门狗的功能和它的应用。 独立看门狗用通俗一点的话来解释就是一个12位的…

在Ubuntu上安装VNC服务器教程

Ubuntu上安装VNC服务器方法:按照root安装TeactVnc,随后运行vncserver输入密码,安装并打开RickVNC客户端,输入服务器的IP,最后连接输入密码即可。 VNC或虚拟网络计算,可让您连接到远程Linux / Unix服务器的…

udp Socket组播 服务器

什么是组播 组播也可以称之为多播这也是 UDP 的特性之一。组播是主机间一对多的通讯模式,是一种允许一个或多个组播源发送同一报文到多个接收者的技术。组播源将一份报文发送到特定的组播地址,组播地址不同于单播地址,它并不属于特定某个主机…

laravel的日志使用说明

文章目录 了解系统的默认支持多个通道时它们的关系如何使用驱动默认日志是同步的 了解系统的默认支持 Laravel 日志基于「 通道 」和 「 驱动 」的。那么这个通道是干嘛的?驱动又是干嘛的? 通道 : 1.它表示了某种日志格式化的方式&#xff…

云动态摘要 2024-06-28

给您带来云厂商的最新动态,最新产品资讯和最新优惠更新。 最新优惠与活动 [新客专享]WeData 限时特惠 腾讯云 2024-06-21 数据分类分级管理,构建数据安全屏障 ,仅需9.9元! 云服务器ECS试用产品续用 阿里云 2024-04-14 云服务器…

游戏AI的创造思路-技术基础-深度学习(3)

继续填坑,本篇介绍深度学习中的长短期记忆网络~~~~ 目录 3.3. 长短期记忆网络(LSTM) 3.3.1. 什么是长短期记忆网络 3.3.2. 形成过程与运行原理 3.3.2.1. 细胞状态与门结构 3.3.2.2. 遗忘门 3.3.2.3. 输入门 3.3.2.4. 细胞状态更新 3.…

Younger 数据集:人工智能生成神经网络

设计和优化神经网络架构通常需要广泛的专业知识,从手工设计开始,然后进行手动或自动化的精细化改进。这种依赖性成为快速创新的重要障碍。认识到从头开始自动生成神经网络架构的复杂性,本文引入了Younger,这是一个开创性的数据集&…

机器学习python实践——关于管道模型Pipeline和网格搜索GridSearchCV的一些个人思考

最近在利用python跟着指导书进行机器学习的实践,在实践中使用到了Pipeline类方法和GridSearchCV类方法,并且使用过程中发现了一些问题,所以本文主要想记录并分享一下个人对于这两种类方法的思考,如果有误,请见谅&#…

Kubernetes 容器编排技术

Kubernetes 容器编排 前言 知识扩展 早在 2015 年 5 月,Kubernetes 在 Google 上的搜索热度就已经超过了 Mesos 和 Docker Swarm,从那儿之后更是一路飙升,将对手甩开了十几条街,容器编排引擎领域的三足鼎立时代结束。 目前,AWS…

蚂蚁- 定存

一:收益变动&&收益重算 1.1: 场景组合 1: 澳门元个人活期,日终余额大于0,当日首次、本周本月非首次系统结息,结息后FCDEPCORE_ASYN_CMD_JOB捞起进行收益计算 【depc_account_revenue_detail】收益日 > 【depc_accoun…