Axial Attention 和 Criss-Cross Attention及其代码实现


Axial Attention 和 Criss-Cross Attention及其代码实现

?

文章目录

    • Axial Attention 和 Criss-Cross Attention及其代码实现
      • 1 Criss - Cross Attention介绍
        • 1.1 引言
        • 1.2 理论实现
          • 1.2.1 获取权重A
          • 1.2.2 Affinity操作
          • 1.3.3 全部信息获取
        • 1.3 代码实现
          • 1.3.1 官方实现
          • 1.3.2 纯pytorch实现
      • 2 Axial Attention 介绍
        • 2.1 引言
        • 2.2 理论实现
        • 2.3 代码实现
          • 2.3.1 Row Attention
          • 2.3.2 Col Attention
          • 2.3.4 总结
        • 2.4 开源的Axial Attention
      • 参考链接

? 传统的卷积操作感受野受限于卷积核大小和stride, 因此对一个尺寸较大的二维图片来说,其左上角的像素点如果要能和右下角的像素点发生联系的话,必须通过不断的加深卷积层层数。而计算机视觉中的self Attention机制能够实现目标像素点与任意一个像素点之间的关联,即融合了全局信息。但目前Self Attention中,以non-local attention为例,计算量较大,尤其在特征图很大时,计算效率非常低下。 那么怎么能既融合全局信息,保持长距离的空间依赖性,又降低计算量呢?Axial Attention 和 Criss-Cross Attention可以很好的解决这个问题。

1 Criss - Cross Attention介绍

1.1 引言

CCNet(Criss Cross Network)的核心是重复十字交叉注意力模块。该模块通过两次CC Attention,可以实现目标特征像素点与特征图中其他所有点之间的相互关系,并用这样的相互关系对目标像素点的特征进行加权,以此获得更加有效的目标特征。

  • non-local 模型中, 因为要考虑特征图每个像素点与所有像素点的关系,时间复杂度和空间复杂度为 O(HW*HW)。
  • 在CC Attention模块中,计算特征图中每个像素点与其对应的行列的像素点的关系,时间复杂度和空间复杂度O(HW*(H + W - 1)) ,相比前者降低了一个1~2个数量级。

在这里插入图片描述

1.2 理论实现

还是基于***self attention***的思路,使用Q和K向量来确定权重,再与V值取权重和。

1
2
3
4
H: (batch,c1,h,w) #输入特征图
Q: (batch,c2,h,w) #Query查询向量
K: (batch,c1,h,w) #Key 键值向量
V: (batch,c1,h,w) #Value 值向量
1.2.1 获取权重A

暂先不考虑Batch

  • step1: 取Q中特征图中某一像素点的所有通道值:q = Q(i,j) , size = (1, c2)
  • step2: 取K特征图中与q同一行和同一列的所有像素点的所有通道值, 交叉位置取了两次,只选一次. k的 size = (c2, h+w - 1)
  • step3: q*k, 得到q_atten, size = (1, h+w - 1), 并对这(h + w -1) 个值进行softmax操作,即权重和为1.
  • step4: 对Q中的所有像素点重复step2 和 step3, 即得到了每个像素点的归一权重。此时atten.size = (batch, h,w, h + w - 1)
1.2.2 Affinity操作

上面已经获取了Atten — (batch, h, w, h + w - 1)。接下来将权重施加在V上。 先不考虑batch

  • step1: 获取Atten中某个像素点的所有权重, A = Atten(i,j) , size = (1, h + w -1)
  • step2: 取V的某一通道Cn 的特征图Vn, size = (h, w) , 选取Vn上与A对应位置的同一行和同一列的数值,记作vn,size = (1, h + w - 1)
  • step3: vn 与 A.T 相乘,即得到加权后的vn值,size = (1,1)
  • step4: 对V中的所有通道重复step2 和 step3操作。
  • step5: 对Atten中的所有像素点重复上述4步操作。
  • step6: 残差网络:H‘ = CCAtten(H) + H

ICCV2019语义分割文章CCNet详解 该文章有很好的动图展示,有助理解。

在这里插入图片描述

1.3.3 全部信息获取

一个CCAttention,只能获取当前位置上同一行和一列的信息,如果两个叠加两个CCAttention,就可以获取全局信息。

在这里插入图片描述

如目标像素点是(Ux,Uy) ,想要获取(θx, θy)的关系。

loop1: 像素点(Ux, θy)和 (θx, Uy)通过一次CCAtten 可以建立与(θx, θy)的关系;

loop2: 像素点是(Ux,Uy) 通过CCAtten可以获取与 像素点(Ux, θy)和 (θx, Uy)的联系,从而间接取得与(θx, θy)的联系。

即通过两次CCAtten , 可以建立目前像素点与任意像素点的信息融合。

1.3 代码实现

1.3.1 官方实现

虽然上面给出了对CCAttention 的理论逻辑的理解,但如果按照该逻辑使用循环进行代码设计,很占计算资源,计算速度也会很慢,而且反向传播也不易做到。我有查看网上的一些实现(包括官方源码),似乎需要自定义的cuda算子,这个扩展性并不友好。

https://github.com/speedinghzl/CCNet

在这里插入图片描述

1.3.2 纯pytorch实现

参考: https://github.com/yearing1017/CCNet_PyTorch ,我在这里添加了详细的注释,有助理解更透彻。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def INF(B,H,W):
    '''
    生成(B*W,H,H)大小的对角线为inf的三维矩阵
    Parameters
    ----------
    B: batch
    H: height
    W: width
    '''
    return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)


class CC_module(nn.Module):
   
    def __init__(self,in_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channels of input
        '''
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1)).to(device)
        self.device = device
         
    def forward(self, x):

        m_batchsize, _, height, width = x.size()
       
        proj_query = self.query_conv(x) #size = (b,c2,h,w), c1 = in_dim, c2 = c1 // 8
       
        #size = (b*w, h, c2)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
       
        #size = (b*h, w, c2)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
       
        proj_key = self.key_conv(x) #size = (b,c2,h,w)
       
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) #size = (b*w,c2,h)
       
        proj_value = self.value_conv(x) #size = (b,c1,h,w)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) #size = (b*w,c1,h)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) #size = (b*h,c1,w)
       
        #size = (b*w, h,h) ,其中[:,i,j]表示Q所有W的第Hi行的所有通道值与K上所有W的第Hj列的所有通道值的向量乘积
        energy_H = torch.bmm(proj_query_H, proj_key_H)
       
        #size = (b,h,w,h) #这里为什么加 INF并没有理解
        energy_H = (energy_H + self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
       
        #size = (b*h,w,w),其中[:,i,j]表示Q所有H的第Wi行的所有通道值与K上所有H的第Wj列的所有通道值的向量乘积
        energy_W = torch.bmm(proj_query_W, proj_key_W)
        energy_W = energy_W.view(m_batchsize,height,width,width) #size = (b,h,w,w)
       
        concate = self.softmax(torch.cat([energy_H, energy_W], 3)) #size = (b,h,w,h+w) #softmax归一化
        #concate = concate * (concate>torch.mean(concate,dim=3,keepdim=True)).float()
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height) #size = (b*w,h,h)
        #print(concate)
        #print(att_H)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width) #size = (b*h,w,w)
       
        #size = (b*w,c1,h) #[:,i,j]表示V所有W的第Ci行通道上的所有H 与att_H的所有W的第Hj列的h权重的乘积  
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1))
        out_H = out_H.view(m_batchsize,width,-1,height).permute(0,2,3,1)  #size = (b,c1,h,w)
       
        #size = (b*h,c1,w) #[:,i,j]表示V所有H的第Ci行通道上的所有W 与att_W的所有H的第Wj列的W权重的乘积  
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1))
        out_W = out_W.view(m_batchsize,height,-1,width).permute(0,2,1,3) #size = (b,c1,h,w)
        #print(out_H.size(),out_W.size())
       
        return self.gamma*(out_H + out_W) + x


if __name__ == '__main__':
   
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cuda:0' if torch.cuda.device_count() > 1 else 'cpu')
    model = CC_module(8,device)
    x = torch.randn(4, 8, 20, 20).to(device)
    out = model(x).(device)
    print(out.shape)

需要注意的是: 这里的self-attention, 主要是空间注意力,并没有涉及通道注意力。

2 Axial Attention 介绍

2.1 引言

Axial Attention,即轴向注意力。之前关注到Axial Attention,是在谷歌的天气预报模型Metnet中有使用到。

在这里插入图片描述

于是去翻看了Axial Attention注意力。

在这里插入图片描述

2.2 理论实现

在这里插入图片描述

这里不谈他后续怎么设计的***Transformer***的,单纯的Axial Attention来看,其实有点类似之前的CC-Attention。

  • CC-Attention 的感受野是与目标像素的同一行和同一列的(H + W - 1)个像素
  • Axial Attention 的感受野是目标像素的同一行(或者同一列) 的W(或H)个像素

具体的思路也还是self-attention。理论实现方法与CC-Attention大同小异,这里就不赘述了。

2.3 代码实现

这里先根据个人理解,给出Axial Attention中的 Row-Attention 和 Col-Attention。

2.3.1 Row Attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#实现轴向注意力中的 row Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.device_count() > 1 else 'cpu')

class RowAttention(nn.Module):
   
    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channel of input img tensor
        q_k_dim: int
            channel of Q, K vector
        device : torch.device
        '''
        super(RowAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device
       
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        self.gamma = nn.Parameter(torch.zeros(1)).to(self.device)
       
    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D , (batch, in_dims, height, width) -- (b,c1,h,w)
        '''
       
        ## c1 = in_dims; c2 = q_k_dim
        b, _, h, w = x.size()
       
        Q = self.query_conv(x) #size = (b,c2, h,w)
        K = self.key_conv(x)   #size = (b, c2, h, w)
        V = self.value_conv(x) #size = (b, c1,h,w)
       
        Q = Q.permute(0,2,1,3).contiguous().view(b*h, -1,w).permute(0,2,1) #size = (b*h,w,c2)
        K = K.permute(0,2,1,3).contiguous().view(b*h, -1,w)  #size = (b*h,c2,w)
        V = V.permute(0,2,1,3).contiguous().view(b*h, -1,w)  #size = (b*h, c1,w)
       
        #size = (b*h,w,w) [:,i,j] 表示Q的所有h的第 Wi行位置上所有通道值与 K的所有h的第 Wj列位置上的所有通道值的乘积,
        # 即(1,c2) * (c2,1) = (1,1)
        row_attn = torch.bmm(Q,K)
        ########
        #此时的 row_atten的[:,i,0:w] 表示Q的所有h的第 Wi行位置上所有通道值与 K的所有行的 所有列(0:w)的逐个位置上的所有通道值的乘积
        #此操作即为 Q的某个(i,j)与 K的(i,0:w)逐个位置的值的乘积,得到行attn
        ########
       
        #对row_attn进行softmax
        row_attn = self.softmax(row_attn) #对列进行softmax,即[k,i,0:w] ,某一行的所有列加起来等于1,
       
        #size = (b*h,c1,w) 这里先需要对row_atten进行 行列置换,使得某一列的所有行加起来等于1
        #[:,i,j]即为V的所有行的某个通道上,所有列的值 与 row_attn的行的乘积,即求权重和
        out = torch.bmm(V,row_attn.permute(0,2,1))
       
        #size = (b,c1,h,2)
        out = out.view(b,h,-1,w).permute(0,2,1,3)
       
        out = self.gamma*out + x

        return out
   
#实现轴向注意力中的 cols Attention
x = torch.randn(4, 8, 16, 20).to(device)
row_attn = RowAttention(in_dim = 8, q_k_dim = 4,device = device).to(device)
print(row_attn(x).size())
2.3.2 Col Attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#实现轴向注意力中的 column Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.device_count() > 1 else 'cpu')

class ColAttention(nn.Module):
   
    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channel of input img tensor
        q_k_dim: int
            channel of Q, K vector
        device : torch.device
        '''
        super(ColAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device
       
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels = self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        self.gamma = nn.Parameter(torch.zeros(1)).to(self.device)
       
    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D , (batch, in_dims, height, width) -- (b,c1,h,w)
        '''
       
        ## c1 = in_dims; c2 = q_k_dim
        b, _, h, w = x.size()
       
        Q = self.query_conv(x) #size = (b,c2, h,w)
        K = self.key_conv(x)   #size = (b, c2, h, w)
        V = self.value_conv(x) #size = (b, c1,h,w)
       
        Q = Q.permute(0,3,1,2).contiguous().view(b*w, -1,h).permute(0,2,1) #size = (b*w,h,c2)
        K = K.permute(0,3,1,2).contiguous().view(b*w, -1,h)  #size = (b*w,c2,h)
        V = V.permute(0,3,1,2).contiguous().view(b*w, -1,h)  #size = (b*w,c1,h)
       
        #size = (b*w,h,h) [:,i,j] 表示Q的所有W的第 Hi行位置上所有通道值与 K的所有W的第 Hj列位置上的所有通道值的乘积,
        # 即(1,c2) * (c2,1) = (1,1)
        col_attn = torch.bmm(Q,K)
        ########
        #此时的 col_atten的[:,i,0:w] 表示Q的所有W的第 Hi行位置上所有通道值与 K的所有W的 所有列(0:h)的逐个位置上的所有通道值的乘积
        #此操作即为 Q的某个(i,j)与 K的(i,0:h)逐个位置的值的乘积,得到列attn
        ########
       
        #对row_attn进行softmax
        col_attn = self.softmax(col_attn) #对列进行softmax,即[k,i,0:w] ,某一行的所有列加起来等于1,
       
        #size = (b*w,c1,h) 这里先需要对col_atten进行 行列置换,使得某一列的所有行加起来等于1
        #[:,i,j]即为V的所有行的某个通道上,所有列的值 与 col_attn的行的乘积,即求权重和
        out = torch.bmm(V,col_attn.permute(0,2,1))
       
        #size = (b,c1,h,w)
        out = out.view(b,w,-1,h).permute(0,2,3,1)
       
        out = self.gamma*out + x

        return out
   
#实现轴向注意力中的 cols Attention
x = torch.randn(4, 8, 16, 20).to(device)
col_attn = ColAttention(8, 4, device = device)

print(col_attn(x).size())
2.3.4 总结

单独使用Row Atten(或者Col Attention),即使是堆叠好几次,也是无法融合全局信息的。一般来说,Row Attention 和 Col Attention要组合起来使用才能更好的融合全局信息。

建议方式:

  • 方法1:out = RowAtten(x) + ColAtten(x)
  • 方法2:x1 = RowAtten(x), out = ColAtten(x1)
  • 方法3:x1 = ColAtten(x), out = RowAtten(x1)

这样的一次 out 类似于 一次CCAtten(x) 操作 。

所以一般至少需要迭代两次上述的任意方法,才能融合到全局信息。

2.4 开源的Axial Attention

github上有人开源了Axial Attention,并且灵活度很高。 https://github.com/lucidrains/axial-attention ,直接安装使用即可。

  • 适用图片
  • 适用视频
  • 适用通道注意力
  • 使用了多头结构
  • 适用Transformer结构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
pip install axial_attention #安装

#Image
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,             # embedding dimension
    dim_index = 1,       # where is the embedding dimension
    dim_heads = 32,      # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

img_attn = attn(img)
print(img_attn.size())


#%%
#Channel-last image latents
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

img_attn = attn(img)
print(img_attn.size())

#%%
#Video
import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 10, 20, 20)

attn = AxialAttention(
    dim = 10,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 5,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (images is 2, video is 3, or more)
)

video_atten = attn(video)
print(video_atten.size())

#%%
# Image Transformer, with reversible network

import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 20, 20)

img1 = transformer(conv1x1(img))
print(img1.size())

目前这个源码还没有啃清楚,有明白的同学欢迎交流。

参考链接

ICCV2019语义分割文章CCNet详解

CCNet: Criss-Cross Attention for Semantic Segmentation论文解读

https://github.com/yearing1017/CCNet_PyTorch/tree/master/CCNet

https://github.com/speedinghzl/CCNet

论文:CCNet: Criss-Cross Attention for Semantic Segmentation

https://github.com/lucidrains/axial-attention

Axial Attention in Multidimensional Transformers

MetNet: A Neural Weather Model for Precipitation Forecasting