Transformer原理及代码实现

概述

  • Transformer是2018年10月,Google发出一篇论文《BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding》提出,横扫了NLP领域11项任务的最佳成绩

  • Transformer的优势:能够利用分布式GPU的方式进行并行训练,在模型训练效率上大大提高;在分析预测更长文本时,捕获间隔较长的语义关联效果更好

  • Transformer模型可以用在机器翻译, 文本生成等.,同时又可以构建预训练语言模型,用于不同任务的迁移学习

Transformer网络架构

transformer的网络结构主要包括四个部分:输入部分,输出部分,编器,解码器
在这里插入图片描述

输入部分

源:源文本嵌入层及其位置编码器 目标:目标文本嵌入层及其位置编码器

结构图:
在这里插入图片描述

  • 文本嵌入层的作用:目的是将文本word2id的数字转变为以向量的方式表示

文本切入层代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import math
class embedded(nn.Module):
    def __init__(self,vocab,emb_dim):
        super().__init__()
        self.vocab=vocab
        self.emb_dim=emb_dim
        self.embedding=nn.Embedding(self.vocab,self.emb_dim)
   
    def forward(self,input):
        embedded=self.embedding(input)     

        return embedded*math.sqrt(self.emb_dim)

if __name__=='__main__':
    vocab=32
    emb_dim=512
    input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
   
    #input=torch.randn(2,3,4) # RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but
 got torch.FloatTensor instead (whil
    emb=embedded(vocab,emb_dim)
    emb_output=emb(input)
    print(emb_output)
运行结果
(nlp) [root@bhs rnn]# python emb.py
tensor([[[ 12.1973,  -3.8676,   0.8053,  ..., -19.2736,  11.9937, -15.9764],
         [ 53.9954,  80.1649, -31.2148,  ...,  -2.2879, -11.6445, -42.3147],
         [-29.1181,  20.0065,   3.5267,  ..., -24.3313,  10.8453, -14.8268],
         [-10.6244, -14.7745,  26.8477,  ..., -17.3613,  33.1439,   5.0513]],

        [[-29.1181,  20.0065,   3.5267,  ..., -24.3313,  10.8453, -14.8268],
         [ 14.4021,  13.1048,  -3.9136,  ..., -12.9840, -62.7084,  14.2422],
         [ 53.9954,  80.1649, -31.2148,  ...,  -2.2879, -11.6445, -42.3147],
         [-28.0308,  26.8346,  -8.4742,  ..., -27.7757, -27.8436, -42.3201]]],
       grad_fn=<MulBackward0>)
  • 位置编码器

编码器部分

由N个编码器组成,每一个编码器有两个子层连接,一个是多头自注意力,规范化层及残差单元,另一个是前馈层,规范化层级残差单元
结构图如下
在这里插入图片描述
编码器中注意力按照如下计算规则的代码实现
在这里插入图片描述
逻辑图:
在这里插入图片描述

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def attion(query,key,value,mask=None,dropout=None):

        d_value=query.size()[-1]
        scores=torch.matmul(query,key.transpose(-2,-1))/math.sqrt(d_value)

        if mask is not None:
                scores=scores.masked_fill(mask==0,-1e9)
        p_attn=F.softmax(scores,dim=-1)
        if dropout is not None:
                p_attn=dropout(p_att)

        return torch.matmul(p_attn,value),p_attn

query=key=value=pe_result
attn,p_attn=attion(query,key,value)
print(attn)