1. torch.unbind 作用
-
说明:
移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片
。 -
参数:
- tensor(Tensor) – 输入张量
- dim(int) – 删除的维度
2. 案例
案例1
x = torch.rand(1,80,3,360,360)y= x.unbind(dim=2)print("y0 shape",y[0].shape)print("y1 shape",y[1].shape)print("y2 shape",y[2].shape)
- 将
shape大小
为(1,80,2,360,360)的x
,沿着dim为2
的维度切片。 - 此时会移除
dim
为2
的维度,得到由3
个 元素大小为(1,80,360,360)
的tensor组成的元组。 - 元组中tensor个数,和指定切片对应维度的值相等。
x = torch.rand(1,80,3,360,360)a =torch.cat(x.unbind(dim=2),1)a.shape
a.shape: torch.Size([1,240,360,360])
- 切片后得到包含3个Tensor的元组,每个tensor大小为
(1,80,360,360)
- 3个tensor沿dim为1进行concate, 因此得到tensor大小为
(1,240,360,360)