总说
一个非常有用的函数,主要是用于“group index”的操作。
先安装一下 https://github.com/rusty1s/pytorch_scatter
1 2 3 4 5 6 7 8 9 10 11 12 | from torch_scatter import scatter import torch src = (torch.rand(2, 6, 2)*4).int() index = torch.tensor([0, 1, 0, 1, 2, 1]) # Broadcasting in the first and last dim. out = scatter(src, index, dim=1, reduce="sum") print(src) print(index) print(out) |
输出
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | tensor([[[1, 3], [3, 3], [3, 2], [2, 1], [1, 0], [0, 2]], [[0, 3], [3, 0], [2, 1], [2, 2], [3, 0], [0, 3]]], dtype=torch.int32) tensor([0, 1, 0, 1, 2, 1]) tensor([[[4, 5], [5, 6], [1, 0]], [[2, 4], [5, 5], [3, 0]]], dtype=torch.int32) |
这里沿着“dim=1”来分别找不同的index,对于“index 0”来说,在第“0”个和第“2”个位置出现,对应的数据是
1 | [[1, 3], [0, 3]] 以及 [[3, 2], [2, 1]] |
这里"sum",那么后,就变成了
1 | [[4, 5], [2, 4]] |
同理解释其他的index。