📜  LSTM –随时间反向传播的推导(1)

📅  最后修改于: 2023-12-03 14:44:04.628000             🧑  作者: Mango

LSTM – 随时间反向传播的推导

引言

LSTM(长短时记忆网络)是一种在深度学习中广泛应用的循环神经网络(RNN)的变体。它通过使用门控机制来对输入数据进行建模,从而更好地捕捉时序信息和长期依赖关系。本文将介绍LSTM的背景知识,并通过推导来详细说明其随时间反向传播算法。

背景知识

在深度学习中,RNN是一种用于处理序列数据的强大工具。然而,传统的RNN在面对长期依赖关系时会遇到梯度消失和梯度爆炸的问题。为了解决这些问题,LSTM引入了三个门控单元:输入门(input gate)、遗忘门(forget gate)和输出门(output gate)。

LSTM结构

LSTM由多个LSTM单元组成,每个LSTM单元包含一个细胞状态(cell state),负责存储和传递信息。每个LSTM单元还由输入门、遗忘门和输出门组成,这些门控制了信息的流动。下面是LSTM单元的结构示意图:

LSTM

LSTM单元结构

LSTM单元的核心操作是在时间步长上进行的,并且可以根据时间步长进行展开(unroll)。在每个时间步长上,LSTM单元接收输入数据、先前的隐藏状态和细胞状态,并计算下一个隐藏状态和细胞状态作为输出。

以下是LSTM单元的数学表达式,用于计算隐藏状态和细胞状态:

输入: 
- xt: 当前时间步长的输入向量
- ht-1: 先前时间步长的隐藏状态
- ct-1: 先前时间步长的细胞状态
参数: 
- Wf, Wi, Wc, Wo: 输入权重
- Uf, Ui, Uc, Uo: 隐藏状态权重
- bf, bi, bc, bo: 偏置

计算:
- ft = sigmoid(Wf * xt + Uf * ht-1 + bf)   # 计算遗忘门
- it = sigmoid(Wi * xt + Ui * ht-1 + bi)   # 计算输入门
- ct_bar = tanh(Wc * xt + Uc * ht-1 + bc)   # 计算细胞候选值
- ct = ft * ct-1 + it * ct_bar             # 更新细胞状态
- ot = sigmoid(Wo * xt + Uo * ht-1 + bo)   # 计算输出门
- ht = ot * tanh(ct)                       # 计算隐藏状态

以上计算是一个LSTM单元的前向传播过程,现在我们来推导LSTM单元的反向传播算法。

LSTM单元反向传播

为了计算LSTM单元的梯度,我们需要根据损失函数对每个参数求偏导数,并使用链式法则沿时间步长进行反向传播。

以下是LSTM单元梯度的计算公式:

输入:
- Δht: 当前时间步长的隐藏状态梯度
- Δct: 当前时间步长的细胞状态梯度
- ht: 当前时间步长的隐藏状态
- ct: 当前时间步长的细胞状态
- ft, it, ct_bar, ot: 当前时间步长的门控值
参数: 
- Wf, Wi, Wc, Wo: 输入权重
- Uf, Ui, Uc, Uo: 隐藏状态权重
- bf, bi, bc, bo: 偏置

计算:
- Δot = tanh(ct) * Δht                              # 计算输出门梯度
- Δct += ot * (1 - tanh(ct) ** 2) * Δht             # 计算细胞状态梯度
- Δft = ct-1 * Δct                                  # 计算遗忘门梯度
- Δct-1 = ft * Δct                                 # 计算上一个细胞状态梯度
- Δit = ct_bar * Δct                                # 计算输入门梯度
- Δct_bar = it * Δct                                # 计算细胞候选值梯度
- Δxt = Wf.T * Δft + Wi.T * Δit + Wc.T * Δct_bar + Wo.T * Δot  # 计算输入向量梯度
- Δht-1 = Uf.T * Δft + Ui.T * Δit + Uc.T * Δct_bar + Uo.T * Δot  # 计算上一个隐藏状态梯度
- ΔWf = Δft * xt.T                                  # 计算输入权重梯度
- ΔWi = Δit * xt.T                                  # 计算输入权重梯度
- ΔWc = Δct_bar * xt.T                              # 计算输入权重梯度
- ΔWo = Δot * xt.T                                  # 计算输入权重梯度
- ΔUf = Δft * ht-1.T                               # 计算隐藏状态权重梯度
- ΔUi = Δit * ht-1.T                               # 计算隐藏状态权重梯度
- ΔUc = Δct_bar * ht-1.T                           # 计算隐藏状态权重梯度
- ΔUo = Δot * ht-1.T                               # 计算隐藏状态权重梯度
- Δbf = np.sum(Δft, axis=1, keepdims=True)           # 计算偏置梯度
- Δbi = np.sum(Δit, axis=1, keepdims=True)           # 计算偏置梯度
- Δbc = np.sum(Δct_bar, axis=1, keepdims=True)       # 计算偏置梯度
- Δbo = np.sum(Δot, axis=1, keepdims=True)           # 计算偏置梯度

更新参数:
- Wf -= learning_rate * ΔWf
- Wi -= learning_rate * ΔWi
- Wc -= learning_rate * ΔWc
- Wo -= learning_rate * ΔWo
- Uf -= learning_rate * ΔUf
- Ui -= learning_rate * ΔUi
- Uc -= learning_rate * ΔUc
- Uo -= learning_rate * ΔUo
- bf -= learning_rate * Δbf
- bi -= learning_rate * Δbi
- bc -= learning_rate * Δbc
- bo -= learning_rate * Δbo
总结

本文对LSTM进行了简要介绍,并详细推导了LSTM单元的随时间反向传播算法。LSTM在处理时序数据和长期依赖关系时表现出色,而推导过程则提供了理解LSTM内部运作的洞察力。希望本文能对程序员们更好地理解和应用LSTM提供帮助。

以上是LSTM – 随时间反向传播的推导内容,希望对你有所帮助!