torch.stack的理解
官方的torch.stack说明文档:
1、进行stack的tensor的维度必须一致,即All tensors need to be of the same size;
2、dim的取值范围:以二维tensor为例,dim的取值范围为[0,1,2]
以具体例子说明stack的用法
a,b,c都是2*3的二维tensor,dim有三种取值情况,分别是dim=0,dim=1,dim=2;
torch.stack()会在指定的dim维度新增一个维度,和concatenate的区别:concatenate不会新增维度,只是将不同的tensor拼接在一起,stack是会新增维度的。
1、dim=0:
a,b,c每个tensor的维度都是2*3的,dim=0会在第一维新增维度,大小为3(因为是3个tensor进行stack),最终的维度为3*2*3;由于是在0维新增维度,所以stack后a,b,c每个tensor是保持不变的;
2、dim=1:
在第二维新增一个维度,最终stack后的tensor维度是2*3*3;效果等价于从一个数据矩阵变成一个数据立方体;第一维大小为2,且a在最前面,表示在最终的tensor中第一维的2是以a为基准的;需要在第二维新增一个大小为3的维度,所以就把a,b,c的第一行拼接在一起形成一个3*3的tensor;另外还要注意一点,stack的结果和tensor的stack顺序是有关的,比如:
这样就是以b为基准进行维度变换了;
3、dim=2:
最终的维度为2*3*3,但是需要注意一下,第一个3表示a的2*3中的3,第二个3是新增的维度大小;比较一下dim=1和dim=2的结果,可以发现其中的元素是互为转置的;这个也好理解,因为扩充的维度位置不一样,那么对应的数据位置也要相应的改变。
写了这么多,最重要的还是要把这种在二维空间中用List形式表达出来的三维tensor和空间中的tensor cube给对应起来,这样理解起来就会容易很多。