文章目录
- torch.Tensor.scatter_
- torch.gather
- 待补充
torch.Tensor.scatter_
torch.Tensor — PyTorch master documentation
??为避免混淆,这里用
??Writes all values from the tensor
??如果用self[a][b]表示self中的元素,则self的index指的是 位于dimension=0维度上的[a]、位于dimension=1维度上的[b]、位于dimension=2维度上的;
1 2 3 4 | # For a 3-D tensor, self is updated as: self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 |
??在上述示例中,等号左边表示的是each value in src的output index (也即是self的index),等号右边表示的是each value in src其本身的index;
??从右边往左边看,src[i][j][k]表示的是a value in src,
??if dim == 0, 则由于
- 右边 src的[i] 位于dimension=0维度上,该维度==0,故the output index is specified by the corresponding value in param_index,即由param_index[i][j][k]指定; the corresponding value的含义是src[i][j][k]位于src中哪个位置,就取param_index中哪个位置的值;因此the size of param_index应该与the size of src一致;
- 右边 src的[j] 位于dimension=1维度上,该维度!==0,故the output index is specified by its index in src;
- 右边 src的[k] 位于dimension=2维度上,该维度!==0,故the output index is specified by its index in src;
??因此,self[index[i][j][k]][j][k] = src[i][j][k]。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | Example: >>> x = torch.rand(2, 5) >>> x tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) >>> y = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) >>> y tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000], [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]]) >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) >>> z tensor([[ 0.0000, 0.0000, 1.2300, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.2300]]) |
??此例中y is a 2-D tensor,由于
- 对于x[0][0]=0.3992,将其写入self[param_index[0][0]][0]中,由于param_index[0][0]=0,即写入self[0][0]中;
- 对于x[0][1]=0.2908,将其写入self[param_index[0][1]][1]中,由于param_index[0][1]=1,即写入self[1][1]中;
- 对于x[0][2]=0.9044,将其写入self[param_index[0][2]][2]中,由于param_index[0][2]=2,即写入self[2][2]中;
- 依次类推…
- 对于x[1][0]=0.5735,将其写入self[param_index[1][0]][0]中,由于param_index[1][0]=2,即写入self[2][0]中;
- 依次类推…
(待阅读)PyTorch笔记之 scatter() 函数 - 那少年和狗20191127
(待阅读)pytorch中torch.Tensor.scatter用法_qq_39004117的博客 20190716
(待阅读)Pytorch笔记 scatter_ - listenviolet 20190525
torch.gather
待补充
待补充
待补充