拼接
维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。 选择一个dim进行拼接的时候其他两个维度大小要相等
对于三维张量,理解 torch.cat
的 dim
参数确实变得更加抽象,但原理是相同的。让我们通过一个具体的例子来说明这一点。
import torch# 创建两个 3D 张量
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])print("Tensor a shape:", a.shape)
print(a)
print("\nTensor b shape:", b.shape)
print(b)# dim=0 连接
c_dim0 = torch.cat([a, b], dim=0)
print("\nResult of torch.cat([a, b], dim=0):")
print("Shape:", c_dim0.shape)
print(c_dim0)# dim=1 连接
c_dim1 = torch.cat([a, b], dim=1)
print("\nResult of torch.cat([a, b], dim=1):")
print("Shape:", c_dim1.shape)
print(c_dim1)# dim=2 连接
c_dim2 = torch.cat([a, b], dim=2)
print("\nResult of torch.cat([a, b], dim=2):")
print("Shape:", c_dim2.shape)
print(c_dim2)
现在让我们详细解释这个三维张量的例子:
-
初始张量:
a
和b
都是形状为 (2, 2, 2) 的 3D 张量。- 可以将它们想象成两个 2x2 的矩阵堆叠在一起。
-
dim=0
连接:- 结果形状:(4, 2, 2)
- 这相当于在第一个维度上堆叠张量。
- 可以理解为将
b
放在a
的"下面",增加了第一个维度的大小。
-
dim=1
连接:- 结果形状:(2, 4, 2)
- 这相当于在第二个维度上堆叠张量。
- 可以理解为在每个 2x2 矩阵的"行"方向上扩展,将
b
的行添加到a
的每个对应部分的下方。
-
dim=2
连接:- 结果形状:(2, 2, 4)
- 这相当于在第三个维度(最内层)上堆叠张量。
- 可以理解为在每个 2x2 矩阵的"列"方向上扩展,将
b
的列添加到a
的每个对应部分的右侧。
理解三维张量 torch.cat
的关键点:
-
维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。
-
dim=0
:增加"深度"或"批次"的数量。 -
dim=1
:增加每个"深度"层或"批次"中的行数。 -
dim=2
:增加每行中的元素数量(列数)。 -
保持其他维度:除了被连接的维度,其他维度的大小保持不变。
-
形状变化:只有指定的
dim
对应的维度大小会改变(增加),其他维度大小保持不变。 -
一致性:要连接的张量在非连接维度上的大小必须相同。
3D Matrix Visualization
Let’s visualize the 3D matrices a and b, and their concatenation results.
Matrix a (2x2x2):
Depth 0: Depth 1:
+---+---+ +---+---+
| 1 | 2 | | 5 | 6 |
+---+---+ +---+---+
| 3 | 4 | | 7 | 8 |
+---+---+ +---+---+
Matrix b (2x2x2):
Depth 0: Depth 1:
+----+----+ +----+----+
| 9 | 10 | | 13 | 14 |
+----+----+ +----+----+
| 11 | 12 | | 15 | 16 |
+----+----+ +----+----+
Concatenation Results:
dim=0 (4x2x2):
Depth 0: Depth 1: Depth 2: Depth 3:
+---+---+ +---+---+ +----+----+ +----+----+
| 1 | 2 | | 5 | 6 | | 9 | 10 | | 13 | 14 |
+---+---+ +---+---+ +----+----+ +----+----+
| 3 | 4 | | 7 | 8 | | 11 | 12 | | 15 | 16 |
+---+---+ +---+---+ +----+----+ +----+----+
dim=1 (2x4x2):
Depth 0: Depth 1:
+---+---+ +---+---+
| 1 | 2 | | 5 | 6 |
+---+---+ +---+---+
| 3 | 4 | | 7 | 8 |
+---+---+ +---+---+
| 9 | 10 | | 13| 14|
+---+---+ +---+---+
| 11| 12 | | 15| 16|
+---+---+ +---+---+
dim=2 (2x2x4):
Depth 0: Depth 1:
+---+---+---+---+ +---+---+---+---+
| 1 | 2 | 9 | 10| | 5 | 6 | 13| 14|
+---+---+---+---+ +---+---+---+---+
| 3 | 4 | 11| 12| | 7 | 8 | 15| 16|
+---+---+---+---+ +---+---+---+---+
当然可以!让我们通过具体的例子来形象地解释不同维度上的拼接。
定义张量
首先,定义三个张量 x
, y
, z
,它们分别具有如下形状:
x
的形状是[2, 1, 3]
y
的形状是[2, 3, 3]
z
的形状是[2, 2, 3]
import torchx = torch.tensor([[[0, 0, 0]], [[0, 0, 0]]])
y = torch.tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]],[[0, 0, 0], [0, 0, 0], [0, 0, 0]]
])
z = torch.tensor([[[0, 0, 0], [0, 0, 0]],[[0, 0, 0], [0, 0, 0]]
])
(1) 在 dim=0
上拼接
在 dim=0
上拼接,相当于增加“深度”或“批次”的数量。每个张量的“深度”都会堆叠起来。
w_dim0 = torch.cat([x, y, z], dim=0)
print(w_dim0.shape)
形象解释:
x:
[[[0, 0, 0]], # 第一层深度[[0, 0, 0]] # 第二层深度
]y:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0]], # 第一层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]] # 第二层深度
]z:
[[[0, 0, 0], [0, 0, 0]], # 第一层深度[[0, 0, 0], [0, 0, 0]] # 第二层深度
]拼接结果 w_dim0:
[[[0, 0, 0]], # x 第一层深度[[0, 0, 0]], # x 第二层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]], # y 第一层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]], # y 第二层深度[[0, 0, 0], [0, 0, 0]], # z 第一层深度[[0, 0, 0], [0, 0, 0]] # z 第二层深度
]
形状:[6, 3, 3]
(2)dim=1
上拼接
在 dim=1
上拼接,相当于增加每个“深度”层中的行数。每个深度层的行数会拼接起来。
w_dim1 = torch.cat([x, y, z], dim=1)
print(w_dim1.shape)
形象解释:
x:
[[[0, 0, 0]], # 第一层深度的第一行[[0, 0, 0]] # 第二层深度的第一行
]y:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0]], # 第一层深度的三行[[0, 0, 0], [0, 0, 0], [0, 0, 0]] # 第二层深度的三行
]z:
[[[0, 0, 0], [0, 0, 0]], # 第一层深度的两行[[0, 0, 0], [0, 0, 0]] # 第二层深度的两行
]拼接结果 w_dim1:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], # 第一层深度的六行[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]] # 第二层深度的六行
]
形状:[2, 6, 3]
当然可以!为了展示如何在 dim=2
(第三个维度)上拼接张量,我们需要确保这些张量在前两个维度上的大小是相同的,而在第三个维度上的大小可以不同。
假设我们定义三个张量 a
, b
, c
,它们分别具有如下形状:
a
的形状是[2, 2, 2]
b
的形状是[2, 2, 3]
c
的形状是[2, 2, 1]
import torcha = torch.tensor([[[1, 2], [3, 4]],[[5, 6], [7, 8]]
])b = torch.tensor([[[9, 10, 11], [12, 13, 14]],[[15, 16, 17], [18, 19, 20]]
])c = torch.tensor([[[21], [22]],[[23], [24]]
])
(3)在 dim=2
上拼接
在 dim=2
上拼接,相当于增加每行中的元素数量(列数)。每个深度层中的列数会拼接起来:
w_dim2 = torch.cat([a, b, c], dim=2)
print(w_dim2)
print(w_dim2.shape)
形象解释:
a:
[[[1, 2], [3, 4]], # 第一层深度的两行两列[[5, 6], [7, 8]] # 第二层深度的两行两列
]b:
[[[9, 10, 11], [12, 13, 14]], # 第一层深度的两行三列[[15, 16, 17], [18, 19, 20]] # 第二层深度的两行三列
]c:
[[[21], [22]], # 第一层深度的两行一列[[23], [24]] # 第二层深度的两行一列
]拼接结果 w_dim2:
[[[1, 2, 9, 10, 11, 21], [3, 4, 12, 13, 14, 22]], # 第一层深度的两行六列[[5, 6, 15, 16, 17, 23], [7, 8, 18, 19, 20, 24]] # 第二层深度的两行六列
]w_dim2 的形状为:[2, 2, 6]
通过在 dim=2
上拼接,结果张量 w_dim2
的第三个维度是各个张量第三个维度的和:2 + 3 + 1 = 6
。
# 代码输出:
# tensor([[[ 1, 2, 9, 10, 11, 21],
# [ 3, 4, 12, 13, 14, 22]],
#
# [[ 5, 6, 15, 16, 17, 23],
# [ 7, 8, 18, 19, 20, 24]]])
#
# 形状: torch.Size([2, 2, 6])
希望这个例子能帮助你更好地理解如何在 dim=2
上拼接张量。
非常好的问题!让我们用书架的比喻来解释这个例子,这将有助于更直观地理解张量的维度。
在这个比喻中:
dim=0
(第一个维度)代表书架的数量dim=1
(第二个维度)代表每个书架的层板数dim=2
(第三个维度)代表每个层板可以放置的书本数量(即层板的宽度)
让我们用这个比喻来解释 a
, b
, 和 c
这三个张量:
-
张量
a
[2, 2, 2]:- 2个书架
- 每个书架有2层层板
- 每个层板可以放2本书
-
张量
b
[2, 2, 3]:- 2个书架
- 每个书架有2层层板
- 每个层板可以放3本书
-
张量
c
[2, 2, 1]:- 2个书架
- 每个书架有2层层板
- 每个层板可以放1本书
当我们在 dim=2
上拼接这些张量时,相当于我们在不改变书架数量和层板数量的情况下,将每个层板变宽,使其可以容纳更多的书。
拼接后的结果 w_dim2
[2, 2, 6]:
- 仍然是2个书架(dim=0 没变)
- 每个书架仍然有2层层板(dim=1 没变)
- 但是现在每个层板可以放6本书了(dim=2 变成了 2+3+1=6)
形象地说:
原来的书架 a: 原来的书架 b: 原来的书架 c:
[□□] [□□□] [□]
[□□] [□□□] [□][□□] [□□□] [□]
[□□] [□□□] [□]拼接后的新书架 w_dim2:
[□□□□□□] (2+3+1 = 6本书)
[□□□□□□][□□□□□□]
[□□□□□□]
每个 □ 代表一本书(或者说张量中的一个元素)。
这个比喻展示了我们如何在不增加书架数量(dim=0)或层板数量(dim=1)的情况下,通过拼接来增加每个层板可以放置的书本数量(dim=2)。这就是在 dim=2
上进行张量拼接的直观理解。