目录
一、通过索引获取值
二、通过索引设置值
在 PyTorch C++ API（libtorch）中对张量进行索引的方式与 Python API 非常相似。None / ... / integer / boolean / slice / tensor 等索引类型在 C++ API 中同样有效，因此可以很方便地在 Python 代码与 C++ 代码之间进行转换。主要区别在于：Python API 中对张量使用的“[ ]”运算符，在 C++ 中被替换成了以下两个成员函数：
1 2 3 | torch::Tensor::index ( ) // 获取值 torch::Tensor::index_put_ ( ) // 设置值 |
官方文档请参阅 PyTorch C++ 文档中的 “Tensor Indexing API” 一节。下面通过示例对比 PyTorch（Python）与 libtorch（C++）中张量的索引/切片方式，每个示例先给出 Python 写法，再给出对应的 C++ 写法：
一、通过索引获取值
1、tensor[Ellipsis, ...] --> tensor.index({Ellipsis, "..."})
# Ellipsis indexing: `...` stands for all leading dimensions, so
# a[..., 2] takes element 2 along the last axis of a 3x3x3 tensor.
import torch

a = torch.linspace(1, 27, steps=27).reshape((3, 3, 3))
print(a)
c = a[..., 2]
print(c)

# ======================= Output ======================= #
# tensor([[[ 1.,  2.,  3.],
#          [ 4.,  5.,  6.],
#          [ 7.,  8.,  9.]],
#
#         [[10., 11., 12.],
#          [13., 14., 15.],
#          [16., 17., 18.]],
#
#         [[19., 20., 21.],
#          [22., 23., 24.],
#          [25., 26., 27.]]])
# tensor([[ 3.,  6.,  9.],
#         [12., 15., 18.],
#         [21., 24., 27.]])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | #include "iostream" #include "torch/script.h" int main() { torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3}); std::cout << a << std::endl; at::Tensor b = a.index({"...", 2}); std::cout << b << std::endl; return 0; } /****************输出结果******************/ (1,.,.) = 1 2 3 4 5 6 7 8 9 (2,.,.) = 10 11 12 13 14 15 16 17 18 (3,.,.) = 19 20 21 22 23 24 25 26 27 [ CPUFloatType{3,3,3} ] 3 6 9 12 15 18 21 24 27 [ CPUFloatType{3,3} ] |
2、tensor[1, 2] --> tensor.index({1, 2})
# Integer indexing: a[1, 2] selects block 1, row 2 -> a 1-D tensor of 3 values.
import torch

a = torch.linspace(1, 27, steps=27).reshape((3, 3, 3))
print(a)
c = a[1, 2]
print(c)

# ======================= Output ======================= #
# tensor([[[ 1.,  2.,  3.],
#          [ 4.,  5.,  6.],
#          [ 7.,  8.,  9.]],
#
#         [[10., 11., 12.],
#          [13., 14., 15.],
#          [16., 17., 18.]],
#
#         [[19., 20., 21.],
#          [22., 23., 24.],
#          [25., 26., 27.]]])
# tensor([16., 17., 18.])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | #include "iostream" #include "torch/script.h" int main() { torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3}); std::cout << a << std::endl; at::Tensor b = a.index({1, 2}); std::cout << b << std::endl; return 0; } /*****************运行结果***************/ (1,.,.) = 1 2 3 4 5 6 7 8 9 (2,.,.) = 10 11 12 13 14 15 16 17 18 (3,.,.) = 19 20 21 22 23 24 25 26 27 [ CPUFloatType{3,3,3} ] 16 17 18 [ CPUFloatType{3} ] |
3、tensor[1::2] --> tensor.index({Slice(1, None, 2)})
# Strided slicing: a[1::2] is the same as a[slice(1, None, 2)],
# which mirrors the C++ Slice(1, None, 2).
import torch

a = torch.linspace(1, 6, steps=6)
print(a)
c = a[slice(1, None, 2)]
print(c)

# ======================= Output ======================= #
# tensor([1., 2., 3., 4., 5., 6.])
# tensor([2., 4., 6.])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 6, 6); std::cout << a << std::endl; at::Tensor b = a.index({Slice(1, None, 2)}); std::cout << b << std::endl; return 0; } /*******************运行结果*********************/ 1 2 3 4 5 6 [ CPUFloatType{6} ] 2 4 6 [ CPUFloatType{3} ] |
3.5、tensor[..., 1:] --> tensor.index({"...", Slice(1)})
# Ellipsis plus an open-ended slice: a[..., 1:] is a[..., slice(1, None)],
# i.e. drop the first element of the last axis -> shape (3, 3, 2).
import torch

a = torch.linspace(1, 27, steps=27).reshape((3, 3, 3))
b = a[..., slice(1, None)]
print(b)

# ======================= Output ======================= #
# tensor([[[ 2.,  3.],
#          [ 5.,  6.],
#          [ 8.,  9.]],
#
#         [[11., 12.],
#          [14., 15.],
#          [17., 18.]],
#
#         [[20., 21.],
#          [23., 24.],
#          [26., 27.]]])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3}); torch::Tensor b = a.index({"...", Slice(1)}); std::cout << b << std::endl; return 0; } /******************运行结果**********************/ (1,.,.) = 2 3 5 6 8 9 (2,.,.) = 11 12 14 15 17 18 (3,.,.) = 20 21 23 24 26 27 [ CPUFloatType{3,3,2} ] |
4、tensor[torch.tensor([1, 2])] --> tensor.index({torch::tensor({1, 2})})
# Tensor (advanced) indexing: the integer tensor b lists which elements of a
# to gather, in order.
import torch

a = torch.linspace(1, 4, steps=4)
b = torch.tensor([0, 1, 3, 2])
c = a[b]
print(a)
print(c)

# ======================= Output ======================= #
# tensor([1., 2., 3., 4.])
# tensor([1., 2., 4., 3.])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 4, 4); torch::Tensor b = torch::tensor({0, 1, 3, 2}); torch::Tensor c = a.index({b}); std::cout << a << std::endl; std::cout << b << std::endl; return 0; } /*******************运行结果********************/ 1 2 3 4 [ CPUFloatType{4} ] 0 1 3 2 [ CPULongType{4} ] |
二、通过索引设置值
1、tensor[1, 2] = 1 --> tensor.index_put_({1, 2}, 1)
# Assignment through integer indexing: overwrite the element at row 1, column 1.
import torch

a = torch.linspace(1, 4, steps=4).reshape((2, 2))
print(a)
a[1, 1] = 100
print(a)

# ======================= Output ======================= #
# tensor([[1., 2.],
#         [3., 4.]])
# tensor([[  1.,   2.],
#         [  3., 100.]])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 4, 4).reshape({2, 2}); std::cout << a << std::endl; a.index_put_({1, 1}, 100); std::cout << a << std::endl; return 0; } /***************运行结果****************/ 1 2 3 4 [ CPUFloatType{2,2} ] 1 2 3 100 [ CPUFloatType{2,2} ] |
2、tensor[Ellipsis, ...] = 1 --> tensor.index_put_({Ellipsis, "..."}, 1)
# Assignment through Ellipsis: `Ellipsis` is the named form of `...`, so this
# sets every element of column 1.
import torch

a = torch.linspace(1, 4, steps=4).reshape((2, 2))
print(a)
a[Ellipsis, 1] = 100
print(a)

# ======================= Output ======================= #
# tensor([[1., 2.],
#         [3., 4.]])
# tensor([[  1., 100.],
#         [  3., 100.]])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 4, 4).reshape({2, 2}); std::cout << a << std::endl; a.index_put_({"...", 1}, 100); std::cout << a << std::endl; return 0; } /***************运行结果****************/ 1 2 3 4 [ CPUFloatType{2,2} ] 1 100 3 100 [ CPUFloatType{2,2} ] |
3、tensor[torch.tensor([1, 2])] = 1 --> tensor.index_put_({torch::tensor({1, 2})}, 1)
# Assignment through tensor indexing: every position listed in b is set to 100.
import torch

a = torch.linspace(1, 4, steps=4)
b = torch.tensor([0, 2])
print(a)
a[b] = 100
print(a)

# ======================= Output ======================= #
# tensor([1., 2., 3., 4.])
# tensor([100.,   2., 100.,   4.])
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | #include "iostream" #include "torch/script.h" using namespace torch::indexing; int main() { torch::Tensor a = torch::linspace(1, 4, 4); torch::Tensor b = torch::tensor({0, 2}); std::cout << a << std::endl; a.index_put_({b}, 100); std::cout << a << std::endl; return 0; } /*****************运行结果*****************/ 1 2 3 4 [ CPUFloatType{4} ] 100 2 100 4 [ CPUFloatType{4} ] |