张量(Tensor),就是多维数组。当维度小于或等于2时,张量又有一些更熟悉的名字:
聚合操作(Aggregation):常见的张量聚合运算包括:求平均、求和、最大值、最小值等。
张量的聚合运算关于dim的规则:
在做张量的运算操作时,dim设定了哪个维,就会遍历这个维去做运算(也称“沿着该维运算”),其他为顺序不变。
keepdim参数,默认设置为False,需要显示地设为True 。示例代码:
import torch
x = torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float32) # 不指定dtype=torch.float32,计算x.mean(dim=0)将会报错RuntimeError: mean(): input dtype should be either floating point or complex dtypes.Got Long instead
x.mean(dim=0)
>>>
tensor([2.5000, 3.5000, 4.5000])
x.mean(dim=1)
>>>
tensor([2., 5.])
# 设置参数keepdim是结果保持正确的维度
x.mean(dim=0,keepdim=True)
>>>
tensor([[2.5000, 3.5000, 4.5000]])
x.mean(dim=1,keepdim=True)
>>>
tensor([[2.],
[5.]])
张量的拼接操作(torch.cat)也是类似的,通过指定维度dim,获得不同的拼接结果。
张量的拼接运算关于dim的规则:
在做张量的运算操作时,dim设定了哪个维,就会遍历这个维去做运算(也称“沿着该维运算”),其他为顺序不变。
x = torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float32)
y = torch.tensor([[7,8,9],[10,11,12]],dtype=torch.float32)
torch.cat((x,y),dim=0)
>>>
tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]])
torch.cat((x,y), dim=1)
>>>
tensor([[ 1., 2., 3., 7., 8., 9.],
[ 4., 5., 6., 10., 11., 12.]])
有时为了适配某些运算,需要对一个张量进行升维或降维。具体而言:
升维: 就是通过torch.unsqueeze(input, dim, out=None)函数,对输入张量的dim位置插入维度1,并返回一个新的张量。与索引相同,dim的值也可以为负数。
降维: 就是通过torch.squeeze(input, dim=None, out=None)函数。