torch.argmax(input) → LongTensor
说明:返回输入input张量中所有元素最大值对应的索引(如果有多个相同则返回第一个)
input(Tensor):输入tensor
example:
- >input = torch.randn(2,3,4)
- tensor([[[-0.5390, -1.0114, 0.4355, -1.1405],
- [ 0.6028, -0.4489, 0.5327, 0.2840],
- [-0.7695, 0.9981, -0.1962, -0.6833]],
-
- [[-1.2388, -0.0077, 1.0434, 0.3299],
- [ 0.1485, 0.7355, 0.1245, -1.2814],
- [ 0.7018, -0.8662, 2.3299, -1.0735]]])
-
- >output = torch.argmax(input)
- tensor(22)
-
- #索引22位置的2.3299最大
torch.argmax(input, dim, keepdim=False) → LongTensor
说明:返回在指定维度最大值的索引
input(Tensor):输入张量
dim(int):要减少的维度,如果是None(或不填),那么就是返回平铺输入的argMax
keepdim(bool):输出张量是否保留该维度
example
- >input = torch.randn(2,3,4)
- tensor([[[ 1.1332, -1.1633, -0.0305, -0.1000],
- [-0.1684, 1.4696, 0.3790, 0.2453],
- [-0.9125, 0.1091, 0.8701, 0.7641]],
-
- [[ 0.9497, 1.3644, 0.9301, 0.0711],
- [-0.2496, -0.9306, -0.4644, -1.3337],
- [ 0.0058, -0.2439, 0.9298, 0.2472]]])
-
- >output = torch.argmax(input,dim=0)
- tensor([[0, 1, 1, 1],
- [0, 0, 0, 0],
- [1, 0, 1, 0]])
-
- #输入3维张量(2,3,4),dim=0表示把最外围的2消除压缩成1,然后找到dim=0中2个值的最大值,比如dim(0)中第一组值1.1332和0.9497,那么取最大就是0,以此类推得到(3,4)的张量
- #假如keepdim=True,那么dim(0)维度会被保留,(2,3,4)变成(1,3,4)
-
- >output = torch.argmax(input,dim=1)
- tensor([[0, 1, 2, 2],
- [0, 0, 0, 2]])
-
- #同上,dim=1表示把第二个维度3消除压缩成1,比如dim(1)中第一组值[1.1332,-0.1684,-0.9125],最大值所以0,其他以此类推得到(2,4)的张量
- #假如keepdim=True,那么dim(1)维度会被保留,(2,3,4)变成(2,1,4)
-
- >output = torch.argmax(input,dim=2)
- tensor([[0, 1, 2],
- [1, 0, 2]])
-
- #也是同上,不作分析了
torch.argmin(input, dim=None, keepdim=False) → LongTensor
说明:返回在指定维度最小值的索引
input(Tensor):输入张量
dim(int):要减少的维度,如果是None(或不填),那么就是返回平铺输入的argMin
keepdim(bool):输出张量是否保留该维度
用法参考就是argMax的第二种格式,只是取小值的索引