20250110_ PyTorch中的张量操作

文章目录

  • 前言
  • 1、torch.cat 函数
  • 2、索引、维度扩展和张量的广播
  • 3、切片操作
    • 3.1、 encoded_first_node
    • 3.2、probs
  • 4、长难代码分析
    • 4.1、selected
      • 4.1.1、multinomial(1)工作原理:
  • 总结


前言


1、torch.cat 函数

torch.cat 函数将两个张量拼接起来,具体地是在第三个维度(dim=2)上进行拼接。注:dim取值范围是0~2

node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)

其中所用参数为:

node_xy = reset_state.node_xy
# shape: (batch, problem, 2)
node_demand = reset_state.node_demand
# shape: (batch, problem)

若要拼接node_xy 与node_demand 需要将node_demand 进行维度拓展node_demand[:, :, None])

node_xy = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
node_demand = torch.tensor([[[10], [20]], [[30], [40]]])
node_xy_demand = torch.tensor([[[ 1,  2, 10], [ 3,  4, 20]],[[ 5,  6, 30], [ 7,  8, 40]]])

2、索引、维度扩展和张量的广播

_ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
  • self.decoder.regret_embedding是一个张量。
  • self.decoder.regret_embedding[None, None, :]增加regret_embedding的维度。维度扩展成 (1, 1, D)
.expand(encoded_nodes.size(0), 1, self.decoder.regret_embedding.size(-1))
  • expand 用来沿特定维度复制张量,以实现广播。
  • encoded_nodes.size(0) 返回的是 encoded_nodes 张量的第一个维度大小。
  • 1 表示第二个维度的大小。
  • self.decoder.regret_embedding.size(-1) 返回的是 self.decoder.regret_embedding 的最后一个维度的大小,也就是嵌入的维度 D

总结: 将张量建立为所需维度在此为三维,使用expand沿着新建维度进行拓展到所需形状


3、切片操作

3.1、 encoded_first_node

 encoded_first_node = self.encoded_nodes[:, [0], :]

这行代码中的切片操作是从 self.encoded_nodes 中提取特定的数据部分:

  • : 表示选择所有批次的样本,保留第一个维度(batch)。
  • [0] 表示选择每个样本中的第一个节点,因此提取的是第一个节点的嵌入向量。
  • : 表示选择该节点的所有嵌入维度,即保留第三个维度(embedding)的所有值。

最终,经过这些操作,encoded_first_node 的形状为 (batch, 1, embedding),即每个样本只包含第一个节点的嵌入向量,保留了嵌入维度。

3.2、probs

probs[:, :, :-1]
  • 这是对 probs 张量的切片操作,作用是从 probs 的第三个维度(即最后一个维度)中移除最后一列。
selected = probs.argmax(dim=2)
  • argmax(dim=2) 表示在 probs 张量的第3维度(类别维度)上,找到每个样本中概率最大的类别索引。

  • argmax 返回的是最大值的索引,而不是最大值本身。


4、长难代码分析

4.1、selected

selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)

prob的shape: (batch, pomo, problem+1)

  • probs.reshape(batch_size * pomo_size, -1)

    • 这一步将 probs 的形状从 (batch, pomo, problem + 1) 转变为 (batch * pomo, problem + 1)。
    • -1:表示自动推算出第二维的大小(即 problem + 1)
    • 新的形状 (batch * pomo, problem + 1)。
  • multinomial(1)

    • multinomial(1) 用于从给定的概率分布中选择一个类别。它会返回一个形状为 (batch_size * pomo_size, 1) 的张量,每一行选择一个元素的索引,代表从 probs 中选择的元素。
  • .squeeze(dim=1)

    • squeeze(dim=1) 是去除第二个维度(索引维度),将形状变为 (batch_size * pomo_size)
  • .reshape(batch_size, pomo_size)

    • 最后,通过 reshape(batch_size, pomo_size) 将张量恢复到原来的形状 (batch_size, pomo_size),即每个批次对应一个选择的元素索引。

4.1.1、multinomial(1)工作原理:

  • 输入:
    multinomial(1) 需要一个形状为 (N, C) 的张量,其中 N 是样本的数量,C 是类别的数量。这个张量表示每个样本在各个类别下的概率分布。

  • 输出:
    multinomial(1) 返回一个形状为 (N, 1) 的张量,每个元素是该样本选择的类别的索引。

具体来说,multinomial(1) 会根据每个类别的概率,从概率分布中选取一个类别。这个选择是随机的,但是会遵循给定的概率分布,即概率较大的类别被选中的几率较高,概率较小的类别被选中的几率较低。


总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/505785.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Springboot+Vue的仓库管理系统

开发一个基于Spring Boot和Vue的仓库管理系统涉及到前端和后端的开发。本文呢,给出一个简单的开发步骤指南,用于指导初入的新手小白如何开始构建这样一个系统,如果**你想直接学习全部内容,可以直接拉到文末哦。** 开始之前呢给小…

快速导入请求到postman

1.确定请求,右键复制为cURL(bash) 2.postman菜单栏Import-Raw text,粘贴复制的内容保存,请求添加成功

预训练语言模型——BERT

1.预训练思想 有了预训练就相当于模型在培养大学生做任务,不然模型初始化再做任务就像培养小学生 当前数据层面的瓶颈是能用于预训练的语料快被用完了 现在有一个重要方向是让机器自己来生成数据并做微调 1.1 预训练(Pre - training)vs. 传…

数字孪生电网有什么作用?实时云渲染技术又如何赋能智慧电网?

电网系统的结构比较复杂,传统运维模式主要是依赖传感器和人工巡检,难以全面监测管理。而数字孪生技术的应用将推动电网智能化、绿色化的高效转型。 智慧电网利用物理模型、现场测量数据和历史数据,结合云计算、物联网、大数据等技术&#xf…

MiniMind - 从0训练语言模型

文章目录 一、关于 MiniMind 📌项目包含 二、📌 Environment三、📌 Quick Start Test四、📌 Quick Start Train0、克隆项目代码1、环境安装2、如果你需要自己训练3、测试模型推理效果 五、📌 Data sources1、分词器&am…

EasyCVR视频汇聚平台如何配置webrtc播放地址?

EasyCVR安防监控视频系统采用先进的网络传输技术,支持高清视频的接入和传输,能够满足大规模、高并发的远程监控需求。平台支持多协议接入,能将接入到视频流转码为多格式进行分发,包括RTMP、RTSP、HTTP-FLV、WebSocket-FLV、HLS、W…

【GlobalMapper精品教程】093:将tif影像色彩映射表(调色板)转为RGB全彩模式

参考阅读:【ArcGIS微课1000例】0137:色彩映射表转为RGB全彩模式 文章目录 一、Globalmapper中显示模式二、ArcGIS中显示模式三、调色板转为RGB全彩模式四、注意事项一、Globalmapper中显示模式 Globalmapper中,将谷歌等多种来源在线影像下载到本地后,可能会遇到以下数据格…

Postman接口测试05|实战项目笔记

目录 一、项目接口概况 二、单接口测试-登录接口:POST 1、正例 2、反例 ①姓名未注册 ②密码错误 ③姓名为空 ④多参 ⑤少参 ⑥无参 三、批量运行测试用例 四、生成测试报告 1、Postman界面生成 2、Newman命令行生成 五、token鉴权(“…

【css】浏览器强制设置元素状态(hover|focus……)

直接上步骤: 打开浏览器控制台 → 找到样式选项 → 找到:hov选项 → 点击:hov选项,会展开【设置元素状态】。 只要选中就会展示出自己写在css里面的该种状态下的样式了。

Springboot——钉钉(站内)实现登录第三方应用

文章目录 前言准备1、创建钉钉应用,并开放网页应用2、配置网页应用各项参数发布版本 前端改造后端逻辑1、获取应用免登录 Access_token2、通过免登录 Access_token 和 Auth_Code 获取对应登录人信息 注意事项 前言 PC端的钉钉中工作台,增加第三方应用&a…

完美解决VMware 17.0 Pro安装ubuntu、Deepin等虚拟机后卡顿、卡死问题

这两天在 VM 17 Pro 中安装了ubuntu 24.1 和Deepin 23.9 等Linux操作系统,在使用过程中出现过数次卡顿、卡死问题,现记录整理解决方法如下: 一、问题描述 安装虚拟机时、以及安装完成后正常使用时出现鼠标点击卡顿、系统反应慢、卡死等问题…

计算机的错误计算(二百零七)

摘要 利用两个数学大模型计算 arccot(0.125664e2)的值,结果保留16位有效数字。 实验表明,它们的输出中分别仅含有3位和1位正确数字。 例1. 计算 arccot(0.125664e2)的值,结果保留16位有效数字。 下面是与一个数学解题器的对话。 以上为与…

Linux内核TTY子系统有什么(6)

接前一篇文章:Linux内核TTY子系统有什么(5) 本文内容参考: Linux TTY子系统框架-CSDN博客 一文彻底讲清Linux tty子系统架构及编程实例-CSDN博客 linux TTY子系统(3) - tty driver_sys tty device driver-CSDN博客 Linux TTY …

03_Redis基本操作

1.Redis查询命令 1.1 官网命查询命令 为了便于学习Redis,官方将其用于操作不同数据类型的命令进行了分类整理。你可以通过访问Redis官方网站上的命令参考页面https://redis.io/commands来查阅这些分组的命令,这有助于更系统地理解和使用Redis的各项功能。 1.2 HELP查询命令…

@LocalBuilder装饰器: 维持组件父子关系

一、前言 当开发者使用Builder做引用数据传递时,会考虑组件的父子关系,使用了bind(this)之后,组件的父子关系和状态管理的父子关系并不一致。为了解决组件的父子关系和状态管理的父子关系保持一致的问题,引入LocalBuilder装饰器。…

kubernetes第七天

1.影响pod调度的因素 nodeName 节点名 resources 资源限制 hostNetwork 宿主机网络 污点 污点容忍 Pod亲和性 Pod反亲和性 节点亲和性 2.污点 通常是作用于worker节点上,其可以影响pod的调度 语法:key[value]:effect effect:[ɪˈfek…

FFmpeg Muxer HLS

使用FFmpeg命令来研究它对HLS协议的支持程度是最好的方法: ffmpeg -h muxerhls Muxer HLS Muxer hls [Apple HTTP Live Streaming]:Common extensions: m3u8.Default video codec: h264.Default audio codec: aac.Default subtitle codec: webvtt. 这里面告诉我…

maven高级(day15)

Maven 是一款构建和管理 Java 项目的工具 分模块设计与开发 所谓分模块设计,顾名思义指的就是我们在设计一个 Java 项目的时候,将一个 Java 项目拆分成多 个模块进行开发。 分模块设计我们在进行项目设计阶段,就可以将一个大的项目拆分成若干…

【json】

JSON JSON是一种轻量级的,按照指定的格式去组织和封装数据的数据交互格式。 本质上是一个带有特定格式的字符串(py打印json时认定为str类型) 在各个编程语言中流通的数据格式,负责不同编程语言中的数据传递和交互,类似于计算机普通话 python与json关系及相互转换…

计算机网络 笔记 数据链路层 2

1,信道划分: (1)时分复用TDM 将时间等分为“TDM帧”,每个TDM帧内部等分为m个时隙,m个用户对应m个时隙 缺点:每个节点只分到了总带宽的1/m,如果有部分的1节点不发出数据,那么就会在这个时间信道被闲置,利用…