Pytorch:cat、stack、squeeze、unsqueeze的用法
torch.cat
在指定原有维度上链接传入的张量,所有传入的张量都必须是相同形状
torch.cat(tensors, dim=0, *, out=None) → Tensor
tensor:相同形状的tensor
dim:链接张量的维度,不能超过传入张量的维度
x = torch.tensor([[0, 1, 2]], dtype= torch.float)
y = torch.tensor([[3, 4, 5]], dtype= torch.int)
print(x.shape, y.shape)
print("-"*50)
z = torch.cat((x, y), dim= 0)
print(z)
print(z.shape)
print("-"*50)
z = torch.cat((x, y), dim= 1)
print(z)
print(z.shape)
torch.stack
在一个新的维度上链接张量,输入张量都必须是相同形状的
torch.stack(tensors, dim=0, *, out=None) → Tensor
tensor:相同形状的张量
dim:插入的张量维度,在0和输出张量维度(比输入张量维度多一个)之间
x = torch.tensor([[0, 1, 2]])
y = torch.tensor([[3, 4, 5]])
print(x.shape, y.shape)
print("-"*50)
z = torch.stack((x, y), dim= 0)
print(z)
print(z.shape)
print("-"*50)
z = torch.stack((x, y), dim= 1)
print(z)
print(z.shape)
print("-"*50)
z = torch.stack((x, y), dim= 2)
print(z)
print(z.shape)
torch.squeeze
压缩张量,去掉输入张量中大小为1的维度,例如:(Ax1xBxCx1)->(AxBxC)
torch.squeeze(input, dim=None) → Tensor
input (Tensor):输入张量
dim (int or tuple of ints, optional):只压缩某个维度,可以不指定,就是压缩所有大小为1的维度
x = torch.tensor([[0, 1, 2]])
y = torch.rand(size= (1, 2, 1, 2, 1))
print(x.shape, y.shape)
print("-"*50)
z = torch.squeeze(x)
print(z)
print(z.shape)
print("-"*50)
z = torch.squeeze(y)
print(z)
print(z.shape)
torch.unsqueeze
在输入张量中指定位置插入一个大小为1的维度
torch.unsqueeze(input, dim) → Tensor
input (Tensor):输入张量
dim (int):插入维度的指定位置
x = torch.randn(size= (2,3))
print(x.shape)
print("-"*50)
z = torch.unsqueeze(x, 0)
print(z)
print(z.shape)
print("-"*50)
z = torch.unsqueeze(x, 1)
print(z)
print(z.shape)