- torch_scatter.scatter(src: Tensor, index: Tensor, dim: int = - 1,
- out: Optional[Tensor] = None,
- dim_size: Optional[int] = None,
- reduce: str = 'sum')→ Tensor
src – The source tensor. ( 源张量)
index – The indices of elements to scatter.(要分散的元素的索引)
dim – The axis along which to index. (default: -1) (要索引的轴(默认值:-1))
out – The destination tensor. (目标张量)
dim_size – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.(如果未给出out,则在 dim 处自动创建尺寸为 dim_size 的输出。如果没有给出 dim_size,则返回根据 index.max() + 1 的最小尺寸输出张量);
reduce – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum") (reduce 操作(“ sum”、“ mul”、“ mean”、“ min”或“ max”);
直观的理解:
对于三维矩阵:
- y = y.scatter(dim,index,src)
-
- #则结果为:
- y[ index[i][j][k] ] [j][k] = src[i][j][k] # if dim == 0
- y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
- y[i][j] [ index[i][j][k] ] = src[i][j][k] # if dim == 2
对于二维矩阵:
- y = y.scatter(dim,index,src)
-
- #则:
- y [ index[i][j] ] [j] = src[i][j] #if dim==0
- y[i] [ index[i][j] ] = src[i][j] #if dim==1
ps: index的维度,必须和src维度相同;
举例:
- >>> src = torch.randn(3, 3)
- >>> src
- tensor([[-1.8801, 0.9740, 1.2865],
- [ 0.3140, 1.2396, -1.3452],
- [-0.8937, 0.6916, -2.0134]])
- >>> y = y.scatter_(0,index,src)
- >>> index = torch.tensor([[0, 1, 0],[1,0,1],[2,1,0]])
- >>> index
- tensor([[0, 1, 0],
- [1, 0, 1],
- [2, 1, 0]])
- >>> y = y.scatter_(0,index,src)
- >>> y
- tensor([[-1.8801, 1.2396, -2.0134],
- [ 0.3140, 0.6916, -1.3452],
- [-0.8937, 0.0000, 0.0000],
- [ 0.0000, 0.0000, 0.0000]])
