注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。
Chapter8 Recurrent Neural Networks
8.7 Backpropagation Through Time
通过时间反向传播(backpropagation through time,BPTT)是循环神经网络中反向传播技术的一个特定应用,它要求我们将循环神经网络的计算图一次展开一个时间步,以获得模型变量和参数之间的依赖关系,然后,基于链式法则,应用反向传播来计算和存储梯度。由于序列可能相当长,因此依赖关系也可能相当长,在下文中,我们将阐明计算过程会发生什么以及如何在实践中解决它们。
8.7.1 RNN’s Gradient Analysis
我们从一个描述循环神经网络工作原理的简化模型开始,此模型忽略了隐状态的特性及其更新方式的细节,且其数学表示没有明确地区分标量、向量和矩阵。在这个简化模型中,我们将时间步 t t t的隐状态表示为 h t h_t ht,输入表示为 x t x_t xt,输出表示为 o t o_t ot,分别使用 w h w_h wh和 w o w_o wo来表示隐藏层和输出层的权重。每个时间步的隐状态和输出可以写为:
h t = f ( x t , h t − 1 , w h ) , o t = g ( h t , w o ) , (2) \begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\o_t &= g(h_t, w_o),\end{aligned}\tag{2} htot=f(xt,ht−1,wh),=g(ht,wo),(2)
其中 f f f和 g g g分别是隐藏层和输出层的变换。因此,我们有一个链 { … , ( x t − 1 , h t − 1 , o t − 1 ) , ( x t , h t , o t ) , … } \{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\} {…,(xt−1,ht−1,ot−1),(xt,ht,ot),…},它们通过循环计算彼此依赖。前向传播相当简单,一次一个时间步的遍历三元组 ( x t , h t , o t ) (x_t, h_t, o_t) (xt,ht,ot),然后通过一个目标函数在所有 T T T个时间步内评估输出 o t o_t ot和对应的标签 y t y_t yt之间的差异:
L ( x 1 , … , x T , y 1 , … , y T , w h , w o ) = 1 T ∑ t = 1 T l ( y t , o t ) . L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t). L(x1,…,xT,y1,…,yT,wh,wo)=T1t=1∑Tl(yt,ot).
对于反向传播,按照链式法则:
∂ L ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ o t ∂ g ( h t , w o ) ∂ h t ∂ h t ∂ w h . \begin{aligned}\frac{\partial L}{\partial w_h} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h}.\end{aligned} ∂wh∂L=T1t=1∑T∂wh∂l(yt,ot)=T1t=1∑T∂ot∂l(yt,ot)∂ht∂g(ht,wo)∂wh∂ht.
在上式乘积的第一项和第二项很容易计算,而第三项比较棘手,因为我们需要循环地计算参数 w h w_h wh对 h t h_t ht的影响。根据式(2), h t h_t ht既依赖于 h t − 1 h_{t-1} ht−1又依赖于 w h w_h wh,其中 h t − 1 h_{t-1} ht−1的计算也依赖于 w h w_h wh。因此,使用链式法则产生:
∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . (3) \frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.\tag{3} ∂wh∂ht=∂wh∂f(xt,ht−1,wh)+∂ht−1∂f(xt,ht−1,wh)∂wh∂ht−1.(3)
为了导出上述梯度,假设我们有三个序列 { a t } , { b t } , { c t } \{a_{t}\},\{b_{t}\},\{c_{t}\} {at},{bt},{ct},当 t = 1 , 2 , … t=1,2,\ldots t=1,2,…时,序列满足 a 0 = 0 a_{0}=0 a0=0且 a t = b t + c t a t − 1 a_{t}=b_{t}+c_{t}a_{t-1} at=bt+ctat−1。对于 t ≥ 1 t\geq 1 t≥1,就很容易得出:
a t = b t + ∑ i = 1 t − 1 ( ∏ j = i + 1 t c j ) b i . (4) a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.\tag{4} at=bt+i=1∑t−1(j=i+1∏tcj)bi.(4)
基于下列公式替换 a t a_t at、 b t b_t bt和 c t c_t ct:
a t = ∂ h t ∂ w h , b t = ∂ f ( x t , h t − 1 , w h ) ∂ w h , c t = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 , \begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}, \\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}},\end{aligned} atbtct=∂wh∂ht,=∂wh∂f(xt,ht−1,wh),=∂ht−1∂f(xt,ht−1,wh),
则:
∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∑ i = 1 t − 1 ( ∏ j = i + 1 t ∂ f ( x j , h j − 1 , w h ) ∂ h j − 1 ) ∂ f ( x i , h i − 1 , w h ) ∂ w h . (5) \frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}.\tag{5} ∂wh∂ht=∂wh∂f(xt,ht−1,wh)+i=1∑t−1(j=i+1∏t∂hj−1∂f(xj,hj−1,wh))∂wh∂f(xi,hi−1,wh).(5)
虽然我们可以使用链式法则递归地计算 ∂ h t / ∂ w h \partial h_t/\partial w_h ∂ht/∂wh,但当 t t t很大时这个链就会变得很长,在实践中是不可取的。
8.7.1.1 Cutting Off Time Steps
我们也可以在 τ \tau τ步后截断式(5)中的求和计算,即将求和终止为 ∂ h t − τ / ∂ w h \partial h_{t-\tau}/\partial w_h ∂ht−τ/∂wh,这种截断是通过在给定数量的时间步之后分离梯度来实现的。这样做导致该模型主要侧重于短期影响,而不是长期影响,在现实中是可取的。
8.7.1.2 Randomly Truncating
我们也可以用一个随机变量替换 ∂ h t / ∂ w h \partial h_t/\partial w_h ∂ht/∂wh,这个随机变量通过序列 ξ t \xi_t ξt实现。序列预定义了 0 ≤ π t ≤ 1 0 \leq \pi_t \leq 1 0≤πt≤1,其中 P ( ξ t = 0 ) = 1 − π t P(\xi_t = 0) = 1-\pi_t P(ξt=0)=1−πt且 P ( ξ t = π t − 1 ) = π t P(\xi_t = \pi_t^{-1}) = \pi_t P(ξt=πt−1)=πt,因此 E [ ξ t ] = 1 E[\xi_t] = 1 E[ξt]=1。使用 z t z_t zt来替换式(3)中的梯度 ∂ h t / ∂ w h \partial h_t/\partial w_h ∂ht/∂wh得到:
z t = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ξ t ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}. zt=∂wh∂f(xt,ht−1,wh)+ξt∂ht−1∂f(xt,ht−1,wh)∂wh∂ht−1.
从 ξ t \xi_t ξt的定义中推导出来 E [ z t ] = ∂ h t / ∂ w h E[z_t] = \partial h_t/\partial w_h E[zt]=∂ht/∂wh,当 ξ t = 0 \xi_t = 0 ξt=0时,递归计算终止在这个 t t t时间步。这导致了不同长度序列的加权和,其中长序列出现的很少,所以需要适当地加大权重。
上图说明了当基于循环神经网络使用通过时间反向传播分析数据集的三种策略:
- 第一行采用随机截断,方法是将文本划分为不同长度的片断;
- 第二行采用常规截断,方法是将文本分解为相同长度的子序列;
- 第三行采用通过时间的完全反向传播,结果是产生了在计算上不可行的表达式。
虽然随机截断在理论上具有吸引力,但由于多种因素在实践中并不总比常规截断更好。首先,在对过去若干个时间步经过反向传播后,观测结果足以捕获实际的依赖关系。其次,增加的方差抵消了时间步数越多梯度越精确的事实。第三,模型可能需要经过一定程度的正则化,以防止过拟合。通过常规截断方法,时间反向传播会引入一定程度的正则化效果,有助于控制模型的复杂度,并提高其泛化能力。
8.7.2 Details of BPTT
下面将展示如何计算目标函数相对于所有模型参数的梯度。简单起见,我们考虑一个没有偏置参数的RNN,其在隐藏层中的激活函数使用恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。对于时间步 t t t,设单个样本的输入及其对应的标签分别为 x t ∈ R d \mathbf{x}_t \in \mathbb{R}^d xt∈Rd和 y t y_t yt。计算隐状态 h t ∈ R h \mathbf{h}_t \in \mathbb{R}^h ht∈Rh和输出 o t ∈ R q \mathbf{o}_t \in \mathbb{R}^q ot∈Rq的方式为:
h t = W h x x t + W h h h t − 1 , o t = W q h h t , \begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\end{aligned} htot=Whxxt+Whhht−1,=Wqhht,
用 l ( o t , y t ) l(\mathbf{o}_t, y_t) l(ot,yt)表示时间步 t t t处的损失函数,则目标函数的总体损失是:
L = 1 T ∑ t = 1 T l ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t). L=T1t=1∑Tl(ot,yt).
模型绘制一个计算图如下所示。
上图中的模型参数是 W h x \mathbf{W}_{hx} Whx、 W h h \mathbf{W}_{hh} Whh和 W q h \mathbf{W}_{qh} Wqh。通常,训练该模型需要分别计算: ∂ L / ∂ W h x \partial L/\partial \mathbf{W}_{hx} ∂L/∂Whx、 ∂ L / ∂ W h h \partial L/\partial \mathbf{W}_{hh} ∂L/∂Whh和 ∂ L / ∂ W q h \partial L/\partial \mathbf{W}_{qh} ∂L/∂Wqh。根据上图中的依赖关系,我们可以沿箭头的相反方向遍历计算图,依次计算和存储梯度。为了灵活地表示链式法则中不同形状的矩阵、向量和标量的乘法,我们继续使用4.7中所述的 prod \text{prod} prod运算符。
首先有:
∂ L ∂ o t = ∂ l ( o t , y t ) T ⋅ ∂ o t ∈ R q . (6) \frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.\tag{6} ∂ot∂L=T⋅∂ot∂l(ot,yt)∈Rq.(6)
接着得到:
∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ ∈ R q × h \frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top\in \mathbb{R}^{q \times h} ∂Wqh∂L=t=1∑Tprod(∂ot∂L,∂Wqh∂ot)=t=1∑T∂ot∂Lht⊤∈Rq×h
其中 ∂ L / ∂ o t \partial L/\partial \mathbf{o}_t ∂L/∂ot是由式(6)给出的。
接下来,如上图所示,在最后的时间步 T T T,目标函数 L L L仅通过 o T \mathbf{o}_T oT依赖于隐状态 h T \mathbf{h}_T hT。因此,我们通过使用链式法可以很容易地得到梯度$\partial L/\partial \mathbf{h}_T :
∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T ∈ R h . (7) \frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}\in \mathbb{R}^h.\tag{7} ∂hT∂L=prod(∂oT∂L,∂hT∂oT)=Wqh⊤∂oT∂L∈Rh.(7)
隐状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \mathbf{h}_t \in \mathbb{R}^h ∂L/∂ht∈Rh在任何 t < T t < T t<T时都可以递归地计算为:
∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t . (8) \frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.\tag{8} ∂ht∂L=prod(∂ht+1∂L,∂ht∂ht+1)+prod(∂ot∂L,∂ht∂ot)=Whh⊤∂ht+1∂L+Wqh⊤∂ot∂L.(8)
对于任何时间步 1 ≤ t ≤ T 1 \leq t \leq T 1≤t≤T展开递归计算得:
∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . (9) \frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.\tag{9} ∂ht∂L=i=t∑T(Whh⊤)T−iWqh⊤∂oT+t−i∂L.(9)
我们可以从式(9)中看到,这个简单的线性例子已经陷入到 W h h ⊤ \mathbf{W}_{hh}^\top Whh⊤的潜在的非常大的幂。在这个幂中,小于1的特征值将会消失,大于1的特征值将会发散。这在数值上是不稳定的,表现形式为梯度消失或梯度爆炸,解决此问题的一种方法如8.7.1中所述。
最后,应用链式规则得:
∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ ∈ R h × d , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ ∈ R h × d , \begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top\in \mathbb{R}^{h \times d},\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top\in \mathbb{R}^{h \times d}, \end{aligned} ∂Whx∂L∂Whh∂L=t=1∑Tprod(∂ht∂L,∂Whx∂ht)=t=1∑T∂ht∂Lxt⊤∈Rh×d,=t=1∑Tprod(∂ht∂L,∂Whh∂ht)=t=1∑T∂ht∂Lht−1⊤∈Rh×d,文章来源:https://www.toymoban.com/news/detail-828077.html
其中 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t ∂L/∂ht由式(7)和式(8)递归计算得到,是影响数值稳定性的关键量。在训练过程中一些中间值会被存储,以避免重复计算,例如存储 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t ∂L/∂ht,以便在计算 ∂ L / ∂ W h x \partial L / \partial \mathbf{W}_{hx} ∂L/∂Whx和 ∂ L / ∂ W h h \partial L / \partial \mathbf{W}_{hh} ∂L/∂Whh时使用。文章来源地址https://www.toymoban.com/news/detail-828077.html
到了这里,关于《动手学深度学习(PyTorch版)》笔记8.7的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!