RNN与反向传播算法(BPTT)的理解


RNN是序列建模的强大工具。
今天主要搬运两天来看到的关于RNN的很好的文章:

  • Anyone Can Learn To Code an LSTM-RNN in Python (Part 1: RNN) 这个博客相当赞,从Toy Code学习确实是个很好的方法(便于理解抓住核心)!。
  • David 9的博客,有上面的译文,还有一些包括GAN(生成对抗网络)的Toy code.
  • RNN反向传播算法(BPTT)的理解和介绍,链接的文章非常清晰!归根结底还是BP,只是隐层更新是牵涉到未来时间戳的输出(因为当前隐层的输出会被记忆并影响未来)。
  • Andrej Karpathy blog on RNN

PS: 第一个链接中的Toy Code做一些说明

图片名称

之所以要循环8(binary_dim=8)次,是因为输入是2维的(a和b各输入一个bit),那么,每个bit只会影响8个时间戳。因此要注意RNN的训练,应该以每一个完整的序列(这里就是a和b两个八位二进制数)作为一个training sample,而非以每一次输入(2 bits)作为一个sample;同样的在反向传播时,也同样遵循这个原则,此处因为每次输入会影响8个时间戳(或者说每8次输入为一个完整的training sample),所以要循环8次。

再然后,第99行(五角星处)的隐层delta更新法则与上面给出的RNN反向传播算法BPTT一文中的下图正好一致!

这里写图片描述

先写这么多。