Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.pyimport torchdef random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoreif __name__ == '__main__':x = torch.arange(64).reshape(1, 16, 4)random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4,  4,  4,  4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/322458.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tomcat端口占用解决方案

Windows操作系统 出现这种情况: Error was Port already in use :40001;nested exception is :java.net.BindException: Address already in use : JVM_Bind; 步骤1:按下winR键,输入cmd 步骤2:输入以下命令 netstat …

基于MPPT最大功率跟踪和SVPWM的光伏三相并网逆变器simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于MPPT最大功率跟踪和SVPWM的光伏三相并网逆变器simulink建模与仿真。包括PV模块,MPPT模块,SVPWM模块,电网模块等。 2.系统仿真结果 1不…

92、动态规划-最小路径和

思路: 还是一样,先使用递归来接,无非是向右和向下,然后得到两种方式进行比较,代码如下: public int minPathSum(int[][] grid) {return calculate(grid, 0, 0);}private int calculate(int[][] grid, int …

吴恩达机器学习笔记:第 9 周-16推荐系统(Recommender Systems) 16.5-16.6

目录 第 9 周 16、 推荐系统(Recommender Systems)16.5 向量化:低秩矩阵分解16.6 推行工作上的细节:均值归一化 第 9 周 16、 推荐系统(Recommender Systems) 16.5 向量化:低秩矩阵分解 在上几节视频中,我们谈到了协同过滤算法&…

如何使用client-go构建pod web shell

代码示例及原理 原理是利用websocket协议实现对pod的exec登录,利用client-go构造与远程apiserver的长连接,将对pod容器的输入和pod容器的输出重定向到我们的io方法中,从而实现浏览器端的虚拟终端的效果消息体结构如下 type Connection stru…

路由策略与路由控制

1.路由控制工具 匹配工具1:访问控制列表 (1)通配符 当进行IP地址匹配的时候,后面会跟着32位掩码位,这32位称为通配符。 通配符,也是点分十进制格式,换算成二进制后,“0”表示“匹配…

element-ui table sortable排序 掉后端接口方式

实例: 官方解释:如果需要后端排序,需将sortable设置为custom,同时在 Table 上监听sort-change事件,在事件回调中可以获取当前排序的字段名和排序顺序,从而向接口请求排序后的表格数据。 1.table上要加 sort-change"sortCha…

15_Scala面向对象编程_访问权限

文章目录 Scala访问权限1.同类中访问2.同包不同类访问3.不同包访问4.子类权限小结 Scala访问权限 知识点概念 private --同类访问private[包名] --包私有; 同类同包下访问protected --同类,或子类 //同包不能访问(default)(public)默认public --公…

Python | Leetcode Python题解之第78题子集

题目: 题解: class Solution:def subsets(self, nums: List[int]) -> List[List[int]]:self.res []self.backtrack([], 0, nums)return self.resdef backtrack(self, sol, index, nums):self.res.append(sol)for i in range(index, len(nums)):self…

物联网实战--平台篇之(四)账户后台交互

目录 一、交互逻辑 二、请求验证码 三、帐号注册 四、帐号/验证码登录 五、重置密码 本项目的交流QQ群:701889554 物联网实战--入门篇https://blog.csdn.net/ypp240124016/category_12609773.html 物联网实战--驱动篇https://blog.csdn.net/ypp240124016/category_12631…

P9422 [蓝桥杯 2023 国 B] 合并数列

P9422 [蓝桥杯 2023 国 B] 合并数列 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 用队列即可 当两个队列队首&#xff1a;a b &#xff0c;弹出 当a < b&#xff0c;把a加给其后一个元素&#xff0c;弹出a 当b < a&#xff0c;把b加给其后一个元素&#xff0c;弹出…

git 配置相关

问题一&#xff1a;ssh-keygen -t ed25519 -C "Gitee SSH Key" 这个命令中的 ed25519 字符是什么意思&#xff1f; ssh-keygen 是一个用于生成SSH密钥的工具&#xff0c;SSH&#xff08;Secure Shell&#xff09;是一种网络协议&#xff0c;用于加密方式远程登录和其…

Docker使用进阶篇

文章目录 1 前言2 使用Docker安装常用镜像示例2.1 Docker安装RabbitMQ2.2 Docker安装Nacos2.3 Docker安装xxl-job&#xff08;推荐该方式构建&#xff09;2.4 Docker安装redis2.5 Docker安装mysql 3 Docker自定义镜像3.1 Dockerfile的基本结构3.2 Dockerfile指令3.3 自定义JDK镜…

免费思维13招之三:赠品型思维

免费思维13招之三:赠品型思维 这节来学习一下免费模式中的三个子思维——赠品型思维、主副型思维和分级型思维。这三个思维有一个共同的名字又叫——产品型思维。 什么是产品型思维?顾名思义,就是在产品上的商业思维。也就是说,通过某一产品的免费来吸引客户,而后进行其…

redis--安装

简介 官网&#xff1a;RedisInsight - The Best Redis GUI 各个版本官网下载地址&#xff1a;http://download.redis.io/releases/ Redis和Memcached是非关系型数据库也称为NoSQL数据库&#xff0c;MySQL、Mariadb、SQL Server、PostgreSQL Oracle 数据库属于关系型数据 应用…

You don’t have permission.

The document “XXX” could not be saved. You don’t have permission. 1.查看修改了iOS系统库导致的, 根据提示, 进入到"XXX"文件中, 然后commandz回退/取消 2. Xcode 调试遇到的报错&#xff08;持续更新&#xff09;

C语言 函数的嵌套与递归 调用

本文 我们来说函数的嵌套调用和递归调用 在很多大型项目中 我们肯定不可能将所有逻辑都写在一个函数中 肯定要按功能拆解成多个特定的功能函数 函数并不允许嵌套调用&#xff0c;但是 允许在逻辑代码中嵌套调用 所谓函数嵌套调用 就是在一个函数中调用另一个函数&#xff0c;而…

计算机毕业设计Python+Spark知识图谱医生推荐系统 医生门诊预测系统 医生数据分析 医生可视化 医疗数据分析 医生爬虫 大数据毕业设计 机器学习

摘 要 随着我国社会经济发展水平的不断提高&#xff0c;人们的物质生活水平也有了很大的改善&#xff0c;越来越多的人不满足于当前的医疗服务质量&#xff0c;由于地域和空间的限制&#xff0c;医疗资源不平衡&#xff0c;无法实现全民共享。针对当今社会中存在的求医难的问题…

2024-05-08 问AI: 在深度学习中,介绍一下RMSProp 优化器

文心一言 在深度学习中&#xff0c;RMSProp&#xff08;Root Mean Square Propagation&#xff09;优化器是一种常用的优化算法&#xff0c;主要用于神经网络训练的梯度下降算法的变体。它是对Adagrad优化器的一种改进&#xff0c;旨在解决Adagrad中学习率过快下降的问题。 R…

Angular中的路由

Angular中的路由 文章目录 Angular中的路由前言一、创建路由二、创建多个组件路由三、创建子路由四、创建多个组件子路由 前言 在Angular中&#xff0c;路由是用于在不同的视图和组件之间导航的机制。Angular提供了一种强大的路由机制来管理单页应用&#xff08;SPA&#xff0…