Pytorch学习 (二十六)—- torch.scatter的使用

总说

一个非常有用的函数,主要是用于“group index”的操作。

先安装一下 https://github.com/rusty1s/pytorch_scatter

1
2
3
4
5
6
7
8
9
10
11
12
from torch_scatter import scatter
import torch

src = (torch.rand(2, 6, 2)*4).int()
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(src)
print(index)
print(out)

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
tensor([[[1, 3],
         [3, 3],
         [3, 2],
         [2, 1],
         [1, 0],
         [0, 2]],

        [[0, 3],
         [3, 0],
         [2, 1],
         [2, 2],
         [3, 0],
         [0, 3]]], dtype=torch.int32)
tensor([0, 1, 0, 1, 2, 1])
tensor([[[4, 5],
         [5, 6],
         [1, 0]],

        [[2, 4],
         [5, 5],
         [3, 0]]], dtype=torch.int32)

这里沿着“dim=1”来分别找不同的index,对于“index 0”来说,在第“0”个和第“2”个位置出现,对应的数据是

1
[[1, 3], [0, 3]]  以及 [[3, 2], [2, 1]]

这里"sum",那么后,就变成了

1
[[4, 5], [2, 4]]

同理解释其他的index。