torch.cat()函数的官方解释,和全面详细的注解

可以直接看3.例子,就明显1和2说的啥了

pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

区别参考这个链接关于torch.stack(),但是本文主要说cat()

前言

该函数总的来说和python内置函数cat()函数没有区别。

通常用来,torch.cat()是为了把函数torch.stack()得到tensor进行拼接而存在的。
就像下图两个人的关系一样(窃喜):
在这里插入图片描述

1. cat()官方解释

----torch.cat(inputs, dim=0) → Tensor

函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。

1
outputs = torch.cat(inputs, dim=0)  #  → Tensor

参数
----inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
----dim : 选择的扩维, 必须在0len(inputs[0])之间,沿着此维连接张量序列。
注解 : torch.cat()可以看做 torch.split()torch.chunk()的反操作。

2. 重点难点

  1. 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  2. 维度不可以超过输入数据的任一个张量的维度

3.举例子

  1. 准备数据,每个的shape都是[2,3]
1
2
3
4
5
6
# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])

  1. 合成inputs
1
2
3
4
5
6
7
8
'inputs为2个形状为[2 , 3]的矩阵 '
inputs = [x1, x2]
print(inputs)
'打印查看'
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32)]

3.查看结果, 测试不同的dim拼接结果

1
2
3
4
5
6
7
8
In    [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4,  3])

In    [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])

In    [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

大家可以复制代码运行一下就会发现其中规律了。

总结

通常用来,把torch.stack得到tensor进行拼接而存在的。