Understanding the Usage of torch.Tensor.scatter_ in PyTorch

Table of Contents

    • torch.Tensor.scatter_
    • torch.gather
    • To be added

torch.Tensor.scatter_

torch.Tensor.scatter_(dim, index, src) → Tensor

torch.Tensor — PyTorch master documentation
torch.Tensor.scatter_

To avoid confusion, scatter_(dim, index, src) → Tensor is written below as scatter_(param_dim, param_index, src) → Tensor.
Writes all values from the tensor src into self at the indices specified in the param_index tensor. For each value in src, its output index is specified by its index in src for dimension != param_dim and by the corresponding value in param_index for dimension = param_dim.

If an element of self is written as self[a][b][c], then the index of self means: [a] lies along dimension 0, [b] along dimension 1, and [c] along dimension 2.

# For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

In the rules above, the left-hand side of each assignment gives the output index of each value in src (that is, the index into self), while the right-hand side gives that value's own index within src.
Reading from right to left, src[i][j][k] denotes a value in src,
and if param_dim == 0, then:

  • on the right-hand side, the [i] of src lies along dimension 0, which equals param_dim, so the output index is specified by the corresponding value in param_index, i.e. by param_index[i][j][k]; "the corresponding value" means that whatever position src[i][j][k] occupies in src, the value at that same position in param_index is used, which is why the size of param_index should match the size of src (more precisely, param_index may not be larger than src in any dimension);
  • on the right-hand side, the [j] of src lies along dimension 1, which is != param_dim, so the output index is specified by its index in src;
  • on the right-hand side, the [k] of src lies along dimension 2, which is != param_dim, so the output index is specified by its index in src.

Therefore, self[param_index[i][j][k]][j][k] = src[i][j][k].
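
To make this concrete, here is a minimal sketch (shapes and index values chosen arbitrarily for illustration) that applies the param_dim == 0 rule with an explicit loop and compares the result against scatter_:

import torch

src = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
# index values are positions along dim 0 of the target; chosen here without
# duplicates per (j, k) so the hand-written loop and scatter_ agree exactly
index = torch.tensor([[[0, 1, 2, 3],
                       [1, 2, 3, 4],
                       [2, 3, 4, 0]],
                      [[4, 0, 1, 2],
                       [0, 1, 2, 3],
                       [3, 4, 0, 1]]])

out_loop = torch.zeros(5, 3, 4)
for i in range(src.size(0)):
    for j in range(src.size(1)):
        for k in range(src.size(2)):
            # self[param_index[i][j][k]][j][k] = src[i][j][k]
            out_loop[index[i, j, k], j, k] = src[i, j, k]

out_scatter = torch.zeros(5, 3, 4).scatter_(0, index, src)
print(torch.equal(out_loop, out_scatter))  # True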

Example:
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> y = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
>>> y
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])

In this example y is a 2-D tensor, and since param_dim=0, self[param_index[i][j]][j] = src[i][j], which gives:

  • for x[0][0]=0.3992, it is written into self[param_index[0][0]][0]; since param_index[0][0]=0, it lands in self[0][0];
  • for x[0][1]=0.2908, it is written into self[param_index[0][1]][1]; since param_index[0][1]=1, it lands in self[1][1];
  • for x[0][2]=0.9044, it is written into self[param_index[0][2]][2]; since param_index[0][2]=2, it lands in self[2][2];
  • and so on…
  • for x[1][0]=0.5735, it is written into self[param_index[1][0]][0]; since param_index[1][0]=2, it lands in self[2][0];
  • and so on…
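
The second example, z, works the same way along param_dim=1: self[i][param_index[i][j]] = 1.23, so row 0 receives 1.23 at column param_index[0][0]=2 and row 1 at column param_index[1][0]=3. When src is a scalar, that scalar is written at every position named by param_index, which is the usual way to build one-hot tensors. A minimal sketch (the labels and class count below are made up for illustration):

import torch

labels = torch.tensor([2, 0, 3])      # hypothetical class labels
num_classes = 4                       # assumed number of classes
one_hot = torch.zeros(labels.size(0), num_classes)
# along dim=1: one_hot[i][labels[i][0]] = 1.0
one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
print(one_hot)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 0., 0., 1.]])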

(To read) PyTorch notes on the scatter() function - 那少年和狗, 2019-11-27

(To read) Usage of torch.Tensor.scatter in PyTorch - qq_39004117's blog, 2019-07-16

(To read) PyTorch notes: scatter_ - listenviolet, 2019-05-25

torch.gather

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
To be added.
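
As a quick preview, a minimal sketch of torch.gather, the read-side counterpart of scatter_: for dim == 0 it computes out[i][j] = input[index[i][j]][j] (the tensors below are made up for illustration).

import torch

t = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
idx = torch.tensor([[0, 2], [2, 1]])
# for dim == 0: out[i][j] = t[idx[i][j]][j]
out = torch.gather(t, 0, idx)
print(out)
# tensor([[1., 6.],
#         [5., 4.]])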

To be added

To be added.