神经网络的工程基础(零)——PyTorch基础

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb

本文将介绍PyTorch的基础。它是神经网络领域常用的建模工具。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、PyTorch的数据基础:张量(Tensor)
  • 二、张量的基本计算

一、PyTorch的数据基础:张量(Tensor)

工欲善其事,必先利其器。在讨论如何实现梯度下降法之前,首先探讨一下PyTorch这个强大的工具。PyTorch是一种备受欢迎的开源机器学习框架,被广泛用于构建、训练和部署神经网络模型,因具有灵活性、动态计算图和卓越的GPU支持而成为神经网络领域的首选。

PyTorch的基础数据结构是张量。张量的创建方式如程序清单1所示(完整代码)。

程序清单1 张量的创建
 1 |  # 使用tensor封装的函数创建tensor2 |  zeros = torch.zeros(2, 3)3 |  tensor([[0., 0., 0.],4 |          [0., 0., 0.]])5 |  6 |  ones = torch.ones(2, 3)7 |  tensor([[1., 1., 1.],8 |          [1., 1., 1.]])9 |  
10 |  torch.manual_seed(1024)
11 |  random = torch.rand(3, 4)
12 |  tensor([[0.8090, 0.7935, 0.2099, 0.9279],
13 |          [0.8136, 0.7422, 0.4769, 0.4955],
14 |          [0.3602, 0.1178, 0.7852, 0.0228]])
15 |  
16 |  # 从Python对象创建
17 |  data = [[2, 3, 4], [1, 0, 1]]
18 |  t_data = torch.tensor(data)
19 |  tensor([[2, 3, 4],
20 |          [1, 0, 1]])
21 |  
22 |  ## 从numpy对象创建
23 |  import numpy as np
24 |  
25 |  n_data = np.array(data)
26 |  tn_data = torch.from_numpy(n_data)
27 |  tensor([[2, 3, 4],
28 |          [1, 0, 1]])
29 |  
30 |  ## Numpy bridge,也就是说对numpy对象的改变会传导到tensor
31 |  n_data += 1
32 |  torch.all(torch.from_numpy(n_data) == tn_data)
33 |  tensor(True)

张量的形状(Shape)是至关重要的概念,它定义了张量的维度以及每个维度的大小。在实际应用中,可以通过使用一系列函数来改变张量的形状,使其适应不同的运算需求,如程序清单2所示。

程序清单2 改变张量的形状
 1 |  # 增加或减少数据的维度2 |  a = torch.rand(3, 4)  # (3, 4)3 |  ## 增加维度4 |  b = a.unsqueeze(0)    # (1, 3, 4)5 |  ## 减少维度6 |  c = b.squeeze(0)      # (3, 4)7 |  ## 数据相同,但是维度不同8 |  print(torch.all(c.eq(b)))    # tensor(True)9 |  print(c.shape == b.shape)    # False
10 |  
11 |  # 变换tensor形状
12 |  data = torch.tensor(range(0, 10))   # tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
13 |  view1 = data.view(2, 5)
14 |  tensor([[0, 1, 2, 3, 4],
15 |          [5, 6, 7, 8, 9]])
16 |  transpose1 = view1.T
17 |  tensor([[0, 5],
18 |          [1, 6],
19 |          [2, 7],
20 |          [3, 8],
21 |          [4, 9]])
22 |  ## 非毗邻存储的对象不能进行view操作
23 |  print(view1.is_contiguous(), transpose1.is_contiguous()) 
24 |  True False
25 |  ## 下面的操作会报错
26 |  view2 = transpose1.view(1, 10)
  1. 程序清单2的第4—6行使用unsqueeze和squeeze函数来增加或减少张量的维度。需要注意的是,这些操作并不会改变张量实际存储的数据,也不会在实质上改变张量的形状。相反,它们只是在张量的形状中添加或删除一个空的维度。具体的变化可以在第8行和第9行中看到。
  2. 为了改变张量的形状,可以使用view函数,如第12—15行所示。但需要注意的是,view函数只能用在毗邻存储的张量1对象上。非毗邻存储的张量只能使用reshape函数来改变形状。尽管这两个函数在功能上相似,但在计算效率上存在显著差异:相较于 view 函数,reshape 的计算开销要大得多。因此,在实际应用中,最好优先选择使用 view 函数。

二、张量的基本计算

张量的运算分为两种:逐元素操作(Element-Wise Operations)和矩阵乘法,这些计算方法在处理数据和构建神经网络模型时都具有重要作用。程序清单6-3中讨论了这些操作,并介绍了PyTorch中的广播机制(Broadcasting Semantics),它在处理不同形状的张量时起到了重要的作用。

程序清单3 张量的常见运算
 1 |  # 逐元素操作2 |  twos = torch.ones(2, 2) * 23 |  tensor([[2., 2.],4 |          [2., 2.]])5 |  powers = twos ** torch.tensor([[1, 2], [3, 4]])6 |  tensor([[ 2.,  4.],7 |          [ 8., 16.]])8 |  9 |  ## tensor广播,tensor broadcasting
10 |  a = torch.tensor(range(1, 7)).view(2, 3)
11 |  tensor([[1, 2, 3],
12 |          [4, 5, 6]])
13 |  b = torch.tensor(range(1, 4)).view(   3)
14 |  tensor([1, 2, 3])
15 |  print(a * b)
16 |  tensor([[ 1,  4,  9],
17 |          [ 4, 10, 18]])    
18 |  ## 关于广播,更复杂的例子
19 |  a =     torch.ones(4, 1, 3, 2)
20 |  b = a * torch.rand(   5, 1, 2)
21 |  print(b.shape)
22 |  torch.Size([4, 5, 3, 2])
23 |  
24 |  # 矩阵运算
25 |  mat1 = torch.randn(3, 4)    # (3, 4)
26 |  mat2 = torch.randn(4, 5)    # (4, 5)
27 |  re = mat1 @ mat2            # (3, 5)
28 |  ## 矩阵运算的广播
29 |  mat1 = torch.randn(5, 1, 3, 4)   # (5, 1, 3, 4)
30 |  mat2 = torch.randn(   8, 4, 5)   # (   8, 4, 5)
31 |  re = mat1 @ mat2                 # (5, 8, 3, 5)
  1. 逐元素操作要求进行运算的两个张量的形状必须相同,如程序清单3中的第2—7行所示。然而,在实际应用中,常常需要对形状不同的张量进行操作。为此,PyTorch引入了广播机制,它允许在一定条件下对形状不同的张量进行逐元素操作,如第9—22行所示。
  2. 广播机制的流程相对复杂,如图1所示,需要注意几个关键步骤。首先,从后向前逐个比较两个张量的维度;接着,对缺失的维度进行扩充(类似于unsqueeze函数的操作);然后,检查广播规则,即两个张量的各分量要么相等,要么其中一个等于1;最后,复制数据,实现广播操作。
  3. 广播机制不仅适用于逐元素操作,它同样影响着张量的矩阵乘法。不同之处在于,当执行矩阵乘法时,广播机制只会作用于前面的维度,而不涉及最后两维,如第29—31行所示。

图1

图1


  1. 毗邻存储(C Contiguous)是一个与硬件相关的概念。简而言之,毗邻存储意味着数据在内存中是连续存储的,这种存储方式能够显著提升数据的读取和计算速度。张量在内存中的存储细节超出了本书的范围,对此感兴趣的读者可以在PyTorch的官方文档中找到更详细的信息。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/331088.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

速看!!!24上软考-信息系统项目管理师真题回忆,考点已更新

整理了24上半年软考高级信息系统项目管理师的考试真题,软考一个批次一套题,现在都是机考,收集题目比较困难,希望能给个小小的赞支持一下。 注意:当天考试的宝子们可以对答案预估分数!后面场次的宝子可以提…

WordPress搭建流程

1. 简介 WordPress 是一个 PHP 编写的网站制作平台。WordPress 本身免费,并且拥有众多的主题可以使用,适合用于搭建个人博客、公司官网、独立站等。 2. 环境准备 2.1 WordPress 下载 WordPress 可以在 Worpress中文官网 下载(如果后续要将后台调成中文的话,一定要从中文…

idea中显示git的Local Changes

1. 第一打开idea中的Settings文件 2. 找到Version Contro中的commint 3. 取消勾选应用即可 4. 本地提交就会显示出来

堆和堆排序

目录 1.二叉树的顺序存储2.堆的性质3.堆的实现3.1 堆的插入(向上调整算法)3.2 堆向下调整算法3.3 堆的创建3.4 堆的删除3.5 全套代码 4.堆排序5.Top-K问题 1.二叉树的顺序存储 顺序存储就是数组存储,一般使用数组只适合完全二叉树&#xff0…

AI革命:生活无处不智能

AI革命:生活无处不智能 😄生命不息,写作不止 🔥 继续踏上学习之路,学之分享笔记 👊 总有一天我也能像各位大佬一样 🏆 博客首页 怒放吧德德 To记录领地 🌝分享学习心得&#xff0…

回见,那果园

记不得何时开始骑行,何时开始爬山,何时偶遇洛师傅,何时进了那半山腰的果园。 似乎很远,又很近。 昨天打电话给果园的师傅,本意问问杏是否熟了,周末骑行过去、进山聊天顺道吃个新鲜。 洛师傅呵呵的笑…

电脑版网易云音乐听歌识曲

文章目录 流程 流程 电脑网易云音乐的搜索框旁边就是听歌识曲功能

NDIS小端口驱动开发(一)

在四种NDIS相关的驱动中,微型端口驱动(也经常翻译为为小端口驱动)位于驱动栈的底部,一般将它理解为NIC设备的驱动程序: 有几种类型的微型端口驱动程序类型: 无连接微型端口驱动程序用于控制无连接网络媒体 ,如以太网的…

JMeter 常见易错问题

1、配置错误: 问题:线程组配置错误,例如设置了错误的线程数或循环次数。 解决方法:检查线程组的配置。确保线程数(即并发用户数量)设置正确,以及循环次数符合预期。如果要模拟不同类型的用户行…

arc-eager算法XJTU-NLP自然语言处理技术期末考知识点

arc-eager算法:以我/做了/一个/梦为例来描述arc-eager算法的四个操作:shift,left-arc,right-arc,reduce XJTU-NLP期末考点2024版 题型:5*6简答题4*15计算题 简答题考点: (1&#…

【30天精通Prometheus:一站式监控实战指南】第8天:redis_exporter从入门到实战:安装、配置详解与生产环境搭建指南,超详细

亲爱的读者们👋   欢迎加入【30天精通Prometheus】专栏!📚 在这里,我们将探索Prometheus的强大功能,并将其应用于实际监控中。这个专栏都将为你提供宝贵的实战经验。🚀   Prometheus是云原生和DevOps的…

1301-习题1-1高等数学

1. 求下列函数的自然定义域 自然定义域就是使函数有意义的定义域。 常见自然定义域: 开根号 x \sqrt x x ​: x ≥ 0 x \ge 0 x≥0自变量为分式的分母 1 x \frac{1}{x} x1​: x ≠ 0 x \ne 0 x0三角函数 tan ⁡ x cot ⁡ x \tan x \cot x …

生产物流智能优化系统

对生产调度、物流调度【车辆路径问题、配送中心拣选问题】智能优化算法研究形成系统性程序,逐步开发设计一个智能优化系统【包括:问题说明、实验界面、算法结构和算法程序应用说明】, 当前完成TSP和集送车辆路径的算法程序,程序效…

Pandas高效数据清洗与转换技巧指南【数据预处理】

三、数据处理 1.合并数据(join、merge、concat函数,append函数) Concat()函数使用 1.concat操作可以将两个pandas表在垂直方向上进行粘合或者堆叠。 join属性为outer,或默认时,返回列名并集,如&#xff…

【大数据】MapReduce JAVA API编程实践及适用场景介绍

目录 1.前言 2.mapreduce编程示例 3.MapReduce适用场景 1.前言 本文是作者大数据系列专栏的其中一篇,前文我们依次聊了大数据的概论、分布式文件系统、分布式数据库、以及计算引擎mapreduce核心概念以及工作原理。 书接上文,本文将会继续聊一下mapr…

K8S认证|CKA题库+答案| 17. 节点维护

17、节点维护 CKA v1.29.0模拟系统免费下载试用: 百度网盘:https://pan.baidu.com/s/1vVR_AK6MVK2Jrz0n0R2GoQ?pwdwbki 题目: 您必须在以下Cluster/Node上完成此考题: Cluster Ma…

无线领夹麦克风哪个品牌好?无线麦克风品牌排行榜前十名推荐

​在当今的数字化浪潮中,个人声音的传播和记录变得尤为重要。无论是会议中心、教室讲台还是户外探险,无线领夹麦克风以其卓越的便携性和连接稳定性,成为了人们沟通和表达的首选工具。面对市场上琳琅满目的无线麦克风选择,为了帮助…

Arduino下载与安装(Windows 10)

Arduino下载与安装(Windows 10) 官网 下载安装 打开官网,点击SOFTWARE,进入到软件下载界面,选择Windows 选择JUST DOWNLOAD 在弹出的界面中,填入电子邮件地址,勾选Privacy Policy,点击JUST DOWNLOAD即可 …

使用SDL_QT直接播放渲染YUV格式文件

0.前要 下载一个文件,名字为 400_300_25.mp4,我们用ffmplay.exe将其转化为yuv文件,具体操作如下: 进入cmd控制台,进入ffmplay.exe文件的目录下,输入ffmpeg -i 文件名.mp4 文件名.yuv 回车,会生…

Java进阶学习笔记15——接口概述

认识接口: Java提供了一个关键字Interface,用这个关键字我们可以定义一个特殊的结构:接口。 接口不能创建对象。 注意:接口不能创建对象,接口是用来被类实现(implements)的,实现接口…