torch.where(condition, x, y) → Tensor
三个参数,condition是选择条件,x是第一个源头1,y是源头2
原理:当满足condition条件时,输出tensor对应位置为x中该位置元素,反之为y中该位置元素。

cond = torch.tensor([[0.6, 0.7],[0.8, 0.2]])
a = torch.zeros(2,2)
b = torch.ones(2,2)
torch.where(cond>0.5, a, b)

torch.gather(input ,dim, index)
其中,input是储存待取元素的tensor,
dim是操作的维度,
index里各元素是索引值,从input中取每个索引值对应的元素
实例理解:
神经网络作手写数字分类(这里数字设为100-109),模型输出结果shape:(4,10),其中10表示每个目标对应十个数字分别的预测概率大小。想得到最终的结果是,对每个目标,预测最可能的三个数字结果,故最终shape为(4,3)
prob = torch.rand(4, 10)
idx = prob.topk(3, dim=1)
idx
idx = idx[1]
idx
label = torch.arange(10)+100
torch.gather(label.expand(4, 10), dim=1, index=idx.long())
# .long()函数作用是向下取整,得到的不是float格式。在这里用不用效果都一样


最终的结果含义是:比如第一行,对第一个目标预测,该目标最可能的数字是106、102和107中的某一个。