本文将从两个角度来理解 “轴” 的概念,着重阐述 1.2 节中的理解,并借此加深问题一和问题二的理解。
一、问题:如何理解 numpy 数组在轴上的 sum 操作
二、问题:torch 张量中的维度 dim 也是如此
一、问题:如何理解 numpy 数组在轴上的 sum 操作
a = np.ones((2, 3, 4), dtype=int)
a.sum(axis=0)
a.sum(axis=1)
a.sum(axis=2)
注:以下与 axis 有关的索引或计数,都从 0 开始,而不是从 1 开始。
1.1 在 axis 上的操作可理解为在该轴上对所有切片的操作
a.sum(axis=0)
等价于 a[0, :, :] + a[1, :, :]
- 表示在第 0 轴上所有切片的求和。
- 数组 a 在第 0 轴上的长度是 2,即 np.ones(2, 3, 4) 中的 2,所以所以 sum 操作有两个切片求和。
a.sum(axis=1)
等价于 a[:, 0, :] + a[:, 1, :] + a[:, 2, :]
- 表示在第 1 轴上所有切片的求和。
- 数组 a 在第 1 轴上的长度是 3,所以 sum 操作有三个切片进行求和。
a.sum(axis=2)
等价于 a[:, :, 0] + a[:, :, 1] + a[:, :, 2] + a[:, :, 3]
- 表示在第 2 轴上所有切片的求和。
- 数组 a 在第 2 轴上的长度是 4,所以 sum 操作有四个切片进行求和。
1.2 axis 代表 “方括号” 的位置
为了理解 aixs 的意义,我们作如下约定:
① 第零层方括号指下图中最外层的黄色方括号 。对于任何一个 n 维数组来说,最外层的方括号只有一对,因为这是一个 n 维数组而不是两个。
② 第一层方括号指下图中的紫色方括号,共 2 对紫色方括号,即两个第一层。
③ 第二层方括号指下图中的蓝色方括号,在每对紫色内有 3 对蓝色方括号,即三个第二层。
④ 第三层方括号指下图中各绿色的数字元素,在每对蓝色内有 4 个绿色数字
⑤ 那么 a = np.ones(2, 3, 4)
中的 2、3、4 即对应上面三句话。
⑥ 如果相对黄色方括号而言,我们也称成对的紫色方括号为一个元素,即黄色方括号内有两个紫色元素,依此类推。
(1)那么, 对于 sum(axis=0)
,axis 0 指 np.ones(2, 3, 4) 中第零个位置的 2,表示第零层内有两个紫色元素。就是对两组紫色括号进行聚合操作,即先把两组紫色括号看成 sum 操作的两个输入,再对应各数字元素相加,结果如下图:
(2) 对于 sum(axis=1)
,axis 1 指 np.ones(2, 3, 4) 中第一个位置(从零开始数)的 3,表示第一层内有三个蓝色元素。就是对三组蓝色括号(暂时只看第一组紫色内的三组蓝色,其他紫色组内同理)进行聚合操作,即把三组紫色括号看成 sum 操作的三个输入,再对应数字相加,得到 [3,3,3,3]。但是有两组紫色,每组紫色内的蓝色都要各自聚合,所以最后的结果是两组 [3,3,3,3],如下图:
(3)对于 sum(axis=2)
,axis 2 指 np.ones(2, 3, 4) 中第二个位置(从零开始数)的 4,表示第二层内有四个绿色元素。就是把各组蓝色内的数字先聚合,即1+1+1+1=4;其他组蓝色中的绿色数字都要按此运算,一共有6组蓝色,所以有6个4。
二、torch 张量中的维度 dim 也是如此
先定义了两个 shape 均为 (2,3,4) 的张量 X,Y,如下所示:
torch.cat()
是将两个张量拼接为一个,其中参数 dim 指定拼接的维度。
(1)当 dim = 0,即将 X 的第零层内的两组方括号与 Y 的第零层内的两组方括号直接拼接。
(2)当 dim = 1,即将 X 的第一层内的三组方括号与 Y 的第一层内的三组方括号直接拼接,融成新的第一层。由于第零层内有两个元素,所以有两个第一层,另外一个第一层也是同理拼接。
(3)当 dim = 2,即将 X 的第二层内的四个数字与 Y 的第二层内的四个数字直接拼接,融成新的第二层。
三种情况的拼接结果如下图所示: