【PyTorch】3.张量类型转换

个人主页:Icomi

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络与人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

大家好我是一颗米,我们已经了解了张量在 PyTorch 中的核心地位,也知道了它在 CPU 和 GPU 上的运算方式,这些都是搭建我们深度学习知识大厦的重要基石。但这座大厦要盖得又高又稳,还需要更多的 “砖块”,接下来我们就来学习其中非常关键的一块 —— 张量的类型转换。

在实际的深度学习项目中,我们会从各种不同的数据源获取数据,这些数据可能最初是以不同的形式存在的。而我们之前学过,在 PyTorch 里计算数据基本都是以张量形式,所以就经常需要进行数据类型的转换。其中,将 numpy 数组和 PyTorch Tensor 相互转化,就是最常使用的一种操作,这也是大家必须掌握的知识点。

numpy 在 Python 的数据处理领域应用非常广泛,很多经典的数据集和算法库都与 numpy 紧密相关。当我们从这些地方获取数据后,往往就需要把 numpy 数组转化为 PyTorch Tensor,才能在 PyTorch 的深度学习模型中进行运算。反之,当我们在 PyTorch 模型中完成某些计算,需要使用一些 numpy 强大的数据分析和处理工具时,又得把 Tensor 转换回 numpy 数组。

这一节,我们主要就来学习如何在 numpy 数组和 PyTorch Tensor 之间自由 “穿梭”,掌握这一关键技能,为我们后续的深度学习扫除障碍。

1. 张量转换为 numpy 数组

使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。

# 1. 将张量转换为 numpy 数组
def tensor_to_numpy():# 创建一个 PyTorch 张量tensor = torch.tensor([2, 3, 4])# 使用张量对象中的 numpy 函数进行转换numpy_array = tensor.numpy()print(type(tensor))print(type(numpy_array))# 注意: tensor 和 numpy_array 共享内存# 修改其中的一个,另外一个也会发生改变# tensor[0] = 100numpy_array[0] = 100print(tensor)print(numpy_array)# 2. 对象拷贝避免共享内存
def tensor_to_numpy_with_copy():# 创建一个 PyTorch 张量tensor = torch.tensor([2, 3, 4])# 使用张量对象中的 numpy 函数进行转换,先克隆张量避免共享内存numpy_array = tensor.clone().numpy()print(type(tensor))print(type(numpy_array))# 修改 numpy 数组,不会影响原张量numpy_array[0] = 100print(tensor)print(numpy_array)if __name__ == "__main__":tensor_to_numpy()tensor_to_numpy_with_copy()

2. numpy 转换为张量

  1. 使用 from_numpy 可以将 ndarray 数组转换为 Tensor,默认共享内存,使用 copy 函数避免共享。
  2. 使用 torch.tensor 可以将 ndarray 数组转换为 Tensor,默认不共享内存。
# 1. 使用 from_numpy 函数
def test01():data_numpy = np.array([2, 3, 4])# 将 numpy 数组转换为张量类型# 1. from_numpy# 2. torch.tensor(ndarray)# 浅拷贝data_tensor = torch.from_numpy(data_numpy)# nunpy 和 tensor 共享内存# data_numpy[0] = 100data_tensor[0] = 100print(data_tensor)print(data_numpy)# 2. 使用 torch.tensor 函数
def test02():data_numpy = np.array([2, 3, 4])data_tensor = torch.tensor(data_numpy)# nunpy 和 tensor 不共享内存# data_numpy[0] = 100data_tensor[0] = 100print(data_tensor)print(data_numpy)

3. 标量张量和数字的转换

对于只有一个元素的张量,使用 item 方法将该值从张量中提取出来。

# 3. 标量张量和数字的转换
def test03():# 当张量只包含一个元素时, 可以通过 item 函数提取出该值data = torch.tensor([30,])print(data.item())data = torch.tensor(30)print(data.item())if __name__ == '__main__':test03()

4.总结

本节内容比较简单, 我们主要学习了 numpy 和 tensor 互相转换的规则, 以及标量张量与数值之间的转换规则。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/7596.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机位:解锁摄影视角的多维度密码

目录 一、机位的构成要素 (一)高度维度 (二)角度维度 (三)距离维度 二、移动机位的魅力 (一)推镜头 (二)拉镜头 (三)摇镜头 …

【例51.3】 平移数据

题目描述 将a数组中第一个元素移到数组末尾,其余数据依次往前平移一个位置。 输入 第一行为数组a的元素个数; 第二行为n个小于1000的正整数。 输出 平移后的数组元素,每个数用一个空格隔开。 样例输入 复制 10 1 2 3 4 5 6 7 8 9 10 样例输出 复…

【Project】CupFox电影网站数据爬取分析与可视化

数据采集清洗与数据存储流程如下图所示。 数据分析与数据可视化流程设计如下 1.使用pymongo从数据库中查询所需的数据。对数据进行处理和分析,进行统计、分类、聚合等操作,提取关键指标和洞察。分析结果可以通过编写Python代码进一步优化、筛选和整理&a…

gradle创建springboot单项目和多模块项目

文章目录 gradle创建springboot项目gradle多模块项目创建 gradle创建springboot项目 适用IDEA很简单,如下图 gradle多模块项目创建 首选创建父项目,然后删除无用内容至下图 选择父项目目录,右键选择模块,创建子项目&#xff08…

数据库的JOIN连接查询算法

文章目录 3.2 Join 算法优化3.1.2 Nested Loop Join(NLJ)3.1.3 Block Nested Loop Join(BNLJ)3.1.4 Index Nested Loop Join(INLJ)3.1.5 Sort Merge Join(SMJ)3.1.6 Hash Join 3.2 J…

Golang Gin系列-8:单元测试与调试技术

在本章中,我们将探讨如何为Gin应用程序编写单元测试,使用有效的调试技术,以及优化性能。这包括设置测试环境、为处理程序和中间件编写测试、使用日志记录、使用调试工具以及分析应用程序以提高性能。 为Gin应用程序编写单元测试 设置测试环境…

二叉树的最大深度(C语言详解版)

一、摘要 嗨喽呀大家,leetcode每日一题又和大家见面啦,今天要讲的是104.二叉树的最大深度,思路互相学习,有什么不足的地方欢迎指正!好啦让我们开始吧!!! 二、题目简介 给定一个二…

开发环境搭建-3:配置 nodejs 开发环境 (fnm+ node + pnpm)

在 WSL 环境中配置:WSL2 (2.3.26.0) Oracle Linux 8.7 官方镜像 node 官网:https://nodejs.org/zh-cn/download 点击【下载】,选择想要的 node 版本、操作系统、node 版本管理器、npm包管理器 根据下面代码提示依次执行对应代码即可 基本概…

HTB:Support[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用ldapsearch…

洛谷P1017 [NOIP2000 提高组] 进制转换

题目链接:P1017 [NOIP2000 提高组] 进制转换 - 洛谷 | 计算机科学教育新生态 题目难度:普及一 题目分析:这是道数学题,我们都知道,首先按照10进制转成n进制的做法:对这个数不断除以n,将余数一一…

php代码审计2 piwigo CMS in_array()函数漏洞

php代码审计2 piwigo CMS in_array()函数漏洞 一、目的 本次学习目的是了解in_array()函数和对项目piwigo中关于in_array()函数存在漏洞的一个审计并利用漏洞获得管理员帐号。 二、in_array函数学习 in_array() 函数搜索数组中是否存在指定的值。 in_array($search,$array…

【2024年华为OD机试】(A卷,200分)- 查找树中元素 (JavaScriptJava PythonC/C++)

一、问题描述 题目解析 题目描述 题目要求根据输入的坐标 (x, y) 在树形结构中找到对应节点的内容值。其中: x 表示节点所在的层数,根节点位于第0层,根节点的子节点位于第1层,依此类推。y 表示节点在该层内的相对偏移,从左至右,第一个节点偏移为0,第二个节点偏移为1,…

WPS数据分析000006

一、排序 开始→ 排序 同文件→选项→自定义序列→输入序列 二、筛选 高级筛选 条件区域要与列表区域一样。 三、条件格式

基于微信小程序的英语学习交流平台设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

记录一个连不上docker中的mysql的问题

引言 使用的debian12,不同发行版可能有些许差异,连接使用的工具是navicat lite 本来是毫无思绪的,以前在云服务器上可能是防火墙的问题,但是这个桌面环境我压根没有使用防火墙。 直到 ying192:~$ mysql -h127.0.0.1 -uroot ERROR 1045 (28…

SpringBoot基础概念介绍-数据源与数据库连接池

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 毛毛张今天介绍的SpringBoot中的基础概念-数据源与数据库连接池,同时介绍SpringBoot整合两种连接池的教程 文章目录 1 数据库与数据库管理系统2 JDBC与数…

深度学习 Pytorch 单层神经网络

神经网络是模仿人类大脑结构所构建的算法,在人脑里,我们有轴突连接神经元,在算法中,我们用圆表示神经元,用线表示神经元之间的连接,数据从神经网络的左侧输入,让神经元处理之后,从右…

GCC之编译(8)AR打包命令

GCC之(8)AR二进制打包命令 Author: Once Day Date: 2025年1月23日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章请查看专栏: Linux实践记录_Once-Day的博客-C…

SpringBoot统一功能处理

一.拦截器 1.拦截器的简单介绍 拦截器是Spring框架提供的核⼼功能之⼀,主要⽤来拦截⽤⼾的请求,在指定⽅法前后,根据业务需要执⾏预先设定的代码. 2.使用 (i).定义拦截器: (ii).注册拦截器 (iii).拦截路径 (iv).实行流程 3.登录校验 4.DispatcherServlet源码&…

31、Java集合概述

目录 一.Collection 二.Map 三.Collection和Map的区别 四.应用场景 集合是一组对象的集合,它封装了对象的存储和操作方式。集合框架提供了一组接口和类,用于存储、访问和操作这些对象集合。这些接口和类定义了不同的数据结构,如列表、集合…