可以直接看3.例子,就明显1和2说的啥了
在
stack() cat()
区别参考这个链接关于torch.stack(),但是本文主要说
前言
该函数总的来说和
通常用来,
就像下图两个人的关系一样(窃喜):
1. cat()官方解释
----
函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。
1 | outputs = torch.cat(inputs, dim=0) # → Tensor |
参数
----inputs : 待连接的张量序列,可以是任意相同
----dim : 选择的扩维, 必须在
注解 :
2. 重点难点
- 输入数据必须是序列,序列中数据是任意相同的
shape 的同类型tensor - 维度不可以超过输入数据的任一个张量的维度
3.举例子
- 准备数据,每个的
shape 都是[2,3]
1 2 3 4 5 6 | # x1 x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int) x1.shape # torch.Size([2, 3]) # x2 x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int) x2.shape # torch.Size([2, 3]) |
- 合成
inputs
1 2 3 4 5 6 7 8 | 'inputs为2个形状为[2 , 3]的矩阵 ' inputs = [x1, x2] print(inputs) '打印查看' [tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32], [22, 32, 42]], dtype=torch.int32)] |
3.查看结果, 测试不同的
1 2 3 4 5 6 7 8 | In [1]: torch.cat(inputs, dim=0).shape Out[1]: torch.Size([4, 3]) In [2]: torch.cat(inputs, dim=1).shape Out[2]: torch.Size([2, 6]) In [3]: torch.cat(inputs, dim=2).shape IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2) |
大家可以复制代码运行一下就会发现其中规律了。
总结
通常用来,把