在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度
需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。
tensor([[-0.4734, 0.4115, -0.9415, -1.1280, -0.1065],
[ 0.1613, 1.2594, 1.1261, 1.3881, 0.1112]])
以上是二维数据情况:
首先生成了一个二维矩阵,其大小为[2,5]
然后,在0维度上插入一个维度,可以看到现在新矩阵a的形状变为[1,2,5],第0维度的大小默认是1
最后,在最后一个维度上插入一个维度,形状变为[2, 5, 1]
print("torch.unsqueeze(a,3) size: {}".format(torch.unsqueeze(a,3).size()))
print("torch.unsqueeze(a,2) size: {}".format(torch.unsqueeze(a,2).size()))
print("torch.unsqueeze(a,1) size: {}".format(torch.unsqueeze(a,1).size()))
print("torch.unsqueeze(a,0) size: {}".format(torch.unsqueeze(a,0).size()))
print("torch.unsqueeze(a,-1) size: {}".format(torch.unsqueeze(a,-1).size()))
print("torch.unsqueeze(a,-2) size: {}".format(torch.unsqueeze(a,-2).size()))
print("torch.unsqueeze(a,-3) size: {}".format(torch.unsqueeze(a,-3).size()))
print("torch.unsqueeze(a,-4) size: {}".format(torch.unsqueeze(a,-4).size()))
torch.unsqueeze(a,3) size: torch.Size([2, 3, 2, 1])
torch.unsqueeze(a,2) size: torch.Size([2, 3, 1, 2])
torch.unsqueeze(a,1) size: torch.Size([2, 1, 3, 2])
torch.unsqueeze(a,0) size: torch.Size([1, 2, 3, 2])
torch.unsqueeze(a,-1) size: torch.Size([2, 3, 2, 1])
torch.unsqueeze(a,-2) size: torch.Size([2, 3, 1, 2])
torch.unsqueeze(a,-3) size: torch.Size([2, 1, 3, 2])
torch.unsqueeze(a,-4) size: torch.Size([1, 2, 3, 2])

对于三维数据input.dim() = 3,因此dim的范围是[-4, 4)