Deriving the backpropagation equations for an LSTM

In this post I will derive the backpropagation equations for an LSTM cell in vectorised form. The post assumes basic knowledge of LSTMs and backpropagation, which you can refresh at Understanding LSTM Networks and A Quick Introduction to Backpropagation.


Forward propagation

We will first remind ourselves of the forward propagation equations. The notation follows Figure 1, and all equations correspond to a single time step t.

Figure 1: Architecture of an LSTM memory cell at time step t
h_{t-1} \in  \mathbb{R}^{n_{h}},  \mspace{31mu} x_{t} \in  \mathbb{R}^{n_{x}} 
z_{t}= [h_{t-1}, x_{t}] \in \mathbb{R}^{n_{h}+n_{x}}
a_{f}= W_{f}\cdot z_{t} + b_{f}, \mspace{31mu}  f_{t}= \sigma(a_{f})
a_{i}= W_{i}\cdot z_{t} + b_{i}, \mspace{40mu}  i_{t}= \sigma(a_{i})
a_{o}= W_{o}\cdot z_{t} + b_{o}, \mspace{34mu}  o_{t}= \sigma(a_{o})
a_{c}= W_{c}\cdot z_{t} + b_{c}, \mspace{36mu}  \hat{c}_t=  tanh(a_{c})

{c}_t=  i_{t}\odot \hat{c}_t + f_{t}\odot c_{t-1}
{h}_t=  o_{t}\odot tanh(c_{t})

v_{t}= W_{v}\cdot h_{t} + b_{v}
\hat{y}_t= softmax(v_{t})
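As a sketch of how these equations map to code, here is one forward time step in NumPy. It assumes column vectors of shape (n, 1), a parameter dictionary with hypothetical keys such as "W_f" and "b_f", and illustrative helper names (lstm_forward_step, init_params); it is not the implementation from the follow-up post. In the code, `@` plays the role of \cdot and `*` the role of \odot.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def softmax(v):
    e = np.exp(v - np.max(v))  # shift by the max for numerical stability
    return e / np.sum(e)


def lstm_forward_step(x_t, h_prev, c_prev, p):
    """One LSTM time step; every vector is a column of shape (n, 1).

    p is a dict (hypothetical layout):
      W_f, W_i, W_o, W_c: (n_h, n_h + n_x)   b_f, b_i, b_o, b_c: (n_h, 1)
      W_v: (n_y, n_h)                         b_v: (n_y, 1)
    """
    z = np.vstack((h_prev, x_t))               # z_t = [h_{t-1}, x_t]
    f = sigmoid(p["W_f"] @ z + p["b_f"])       # forget gate
    i = sigmoid(p["W_i"] @ z + p["b_i"])       # input gate
    o = sigmoid(p["W_o"] @ z + p["b_o"])       # output gate
    c_hat = np.tanh(p["W_c"] @ z + p["b_c"])   # candidate cell state
    c = i * c_hat + f * c_prev                 # '*' is the element-wise product
    h = o * np.tanh(c)
    y_hat = softmax(p["W_v"] @ h + p["b_v"])   # v_t, then softmax
    cache = {"z": z, "f": f, "i": i, "o": o, "c_hat": c_hat,
             "c": c, "c_prev": c_prev, "h": h}
    return y_hat, h, c, cache


def init_params(n_x, n_h, n_y, seed=0):
    """Small random weights and zero biases (an arbitrary illustrative choice)."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in "fioc":
        p["W_" + g] = rng.normal(0.0, 0.1, (n_h, n_h + n_x))
        p["b_" + g] = np.zeros((n_h, 1))
    p["W_v"] = rng.normal(0.0, 0.1, (n_y, n_h))
    p["b_v"] = np.zeros((n_y, 1))
    return p


p = init_params(n_x=3, n_h=4, n_y=2)
x = np.ones((3, 1))
h0, c0 = np.zeros((4, 1)), np.zeros((4, 1))
y_hat, h1, c1, cache = lstm_forward_step(x, h0, c0, p)
print(y_hat.ravel(), h1.ravel())
```

The cache keeps the intermediate values that the backward equations below reuse, so nothing has to be recomputed.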

Backward propagation

Backpropagation through an LSTM is not as straightforward as through other common deep learning architectures, because the gates and the cell state interact multiplicatively, and both the hidden state and the cell state carry gradients back across time steps. Nonetheless, the approach is largely the same: identify the dependencies and recursively apply the chain rule.

Figure 2: Backpropagation through an LSTM memory cell

A softmax output layer with a cross-entropy loss is used. The standard result for the derivative of the loss with respect to the softmax input v_{t} (\frac{\partial J}{\partial v_{t}}) is used directly; a detailed derivation can be found here.

\frac{\partial J}{\partial v_{t}} = \hat{y}_t - y_{t}

\frac{\partial J}{\partial W_{v}} = \frac{\partial J}{\partial v_{t}} \cdot \frac{\partial v_{t}}{\partial W_{v}} \Rightarrow \frac{\partial J}{\partial W_{v}} = \frac{\partial J}{\partial v_{t}} \cdot h_{t}^T

\frac{\partial J}{\partial b_{v}} = \frac{\partial J}{\partial v_{t}} \cdot \frac{\partial v_{t}}{\partial b_{v}} \Rightarrow \frac{\partial J}{\partial b_{v}} = \frac{\partial J}{\partial v_{t}}
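The output-layer gradients can be sanity-checked numerically. The sketch below (random toy sizes, a one-hot target, and a helper named cross_entropy, all assumptions) compares \frac{\partial J}{\partial W_{v}} = (\hat{y}_t - y_{t}) \cdot h_{t}^T against central finite differences.

```python
import numpy as np


def softmax(v):
    e = np.exp(v - np.max(v))
    return e / np.sum(e)


def cross_entropy(v, y):
    # J = -sum(y * log(softmax(v))) for a one-hot target y
    return -float(np.sum(y * np.log(softmax(v))))


rng = np.random.default_rng(1)
n_h, n_y = 4, 3
W_v = rng.normal(size=(n_y, n_h))
b_v = rng.normal(size=(n_y, 1))
h = rng.normal(size=(n_h, 1))
y = np.array([[0.0], [1.0], [0.0]])   # one-hot target

v = W_v @ h + b_v
dv = softmax(v) - y      # dJ/dv_t = y_hat - y
dW_v = dv @ h.T          # dJ/dW_v = dJ/dv_t . h_t^T, shape (n_y, n_h)
db_v = dv                # dJ/db_v = dJ/dv_t

# Central finite-difference estimate of every entry of dJ/dW_v
eps = 1e-6
num_dW_v = np.zeros_like(W_v)
for r in range(n_y):
    for col in range(n_h):
        W_plus, W_minus = W_v.copy(), W_v.copy()
        W_plus[r, col] += eps
        W_minus[r, col] -= eps
        num_dW_v[r, col] = (cross_entropy(W_plus @ h + b_v, y)
                            - cross_entropy(W_minus @ h + b_v, y)) / (2 * eps)
print(np.max(np.abs(num_dW_v - dW_v)))
```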

hidden state
\frac{\partial J}{\partial h_{t}} = \frac{\partial J}{\partial v_{t}} \cdot \frac{\partial v_{t}}{\partial h_{t}} \Rightarrow \frac{\partial J}{\partial h_{t}} = W_{v}^T \cdot \frac{\partial J}{\partial v_{t}}

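Because h_{t} also feeds the next time step, the gradient flowing back from step t+1 is added to it. Here \frac{\partial J}{\partial h_{next}} is the \frac{\partial J}{\partial h_{t-1}} computed while processing step t+1 (zero at the last time step):

\frac{\partial J}{\partial h_{t}} += \frac{\partial J}{\partial h_{next}}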
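To see why the two contributions are summed, the sketch below stands in for the loss of all later time steps with a linear term whose gradient with respect to h_{t} is a random vector dh_next (an assumption made purely for the check), then compares the analytic \frac{\partial J}{\partial h_{t}} with a finite-difference estimate.

```python
import numpy as np


def softmax(v):
    e = np.exp(v - np.max(v))
    return e / np.sum(e)


rng = np.random.default_rng(2)
n_h, n_y = 4, 3
W_v = rng.normal(size=(n_y, n_h))
b_v = rng.normal(size=(n_y, 1))
h = rng.normal(size=(n_h, 1))
y = np.array([[1.0], [0.0], [0.0]])
dh_next = rng.normal(size=(n_h, 1))   # random stand-in for the gradient from step t+1


def total_loss(h):
    # Loss at step t plus a linear surrogate for the future loss,
    # whose gradient with respect to h is exactly dh_next
    J_t = -float(np.sum(y * np.log(softmax(W_v @ h + b_v))))
    return J_t + (dh_next.T @ h).item()


dv = softmax(W_v @ h + b_v) - y
dh = W_v.T @ dv          # path through the output layer
dh += dh_next            # path through the next time step

eps = 1e-6
num_dh = np.zeros_like(h)
for k in range(n_h):
    step = np.zeros_like(h)
    step[k] = eps
    num_dh[k] = (total_loss(h + step) - total_loss(h - step)) / (2 * eps)
print(np.max(np.abs(num_dh - dh)))
```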
output gate
\frac{\partial J}{\partial o_{t}} = \frac{\partial J}{\partial h_{t}} \cdot \frac{\partial h_{t}}{\partial o_{t}} \Rightarrow \frac{\partial J}{\partial o_{t}} = \frac{\partial J}{\partial h_{t}} \odot tanh(c_{t})

\frac{\partial J}{\partial a_{o}} = \frac{\partial J}{\partial o_{t}} \cdot \frac{\partial o_{t}}{\partial a_{o}} \Rightarrow \frac{\partial J}{\partial a_{o}} = \frac{\partial J}{\partial h_{t}} \odot tanh(c_{t}) \odot \frac{d(\sigma (a_{o}))}{da_{o}}  \newline \Rightarrow \frac{\partial J}{\partial a_{o}} = \frac{\partial J}{\partial h_{t}} \odot tanh(c_{t}) \odot \sigma (a_{o})(1- \sigma (a_{o}))  \newline \Rightarrow \frac{\partial J}{\partial a_{o}} = \frac{\partial J}{\partial h_{t}} \odot tanh(c_{t}) \odot o_{t}(1- o_{t})

\frac{\partial J}{\partial W_{o}} = \frac{\partial J}{\partial a_{o}} \cdot \frac{\partial a_{o}}{\partial W_{o}}  \Rightarrow \frac{\partial J}{\partial W_{o}} = \frac{\partial J}{\partial a_{o}} \cdot z_{t}^T

\frac{\partial J}{\partial b_{o}} = \frac{\partial J}{\partial a_{o}} \cdot \frac{\partial a_{o}}{\partial b_{o}} \Rightarrow \frac{\partial J}{\partial b_{o}} = \frac{\partial J}{\partial a_{o}}
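The output-gate formulas can be checked the same way. In this sketch the upstream gradient \frac{\partial J}{\partial h_{t}} is replaced by an arbitrary vector dh through the probe loss J = dh^T h_{t} (a testing trick, not part of the model), and z_{t}, c_{t-1} and the weights are random.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(3)
n_h, n_x = 3, 2
W = {g: rng.normal(size=(n_h, n_h + n_x)) for g in "fioc"}
b = {g: rng.normal(size=(n_h, 1)) for g in "fioc"}
z = rng.normal(size=(n_h + n_x, 1))    # stands in for z_t = [h_{t-1}, x_t]
c_prev = rng.normal(size=(n_h, 1))
dh = rng.normal(size=(n_h, 1))         # arbitrary upstream dJ/dh_t


def cell(W):
    f = sigmoid(W["f"] @ z + b["f"])
    i = sigmoid(W["i"] @ z + b["i"])
    o = sigmoid(W["o"] @ z + b["o"])
    c_hat = np.tanh(W["c"] @ z + b["c"])
    c = i * c_hat + f * c_prev
    return f, i, o, c_hat, c, o * np.tanh(c)


def probe_loss(W):
    # J = dh^T h_t, so dJ/dh_t is exactly the chosen vector dh
    return (dh.T @ cell(W)[-1]).item()


f, i, o, c_hat, c, h = cell(W)
da_o = dh * np.tanh(c) * o * (1 - o)   # dJ/da_o
dW_o = da_o @ z.T                      # dJ/dW_o
db_o = da_o                            # dJ/db_o

eps = 1e-6
num_dW_o = np.zeros_like(W["o"])
for r in range(n_h):
    for col in range(n_h + n_x):
        W_plus = {g: m.copy() for g, m in W.items()}
        W_minus = {g: m.copy() for g, m in W.items()}
        W_plus["o"][r, col] += eps
        W_minus["o"][r, col] -= eps
        num_dW_o[r, col] = (probe_loss(W_plus) - probe_loss(W_minus)) / (2 * eps)
print(np.max(np.abs(num_dW_o - dW_o)))
```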
cell state
\frac{\partial J}{\partial c_{t}} = \frac{\partial J}{\partial h_{t}} \cdot \frac{\partial h_{t}}{\partial c_{t}} \Rightarrow \frac{\partial J}{\partial c_{t}} = \frac{\partial J}{\partial h_{t}} \odot o_{t} \odot (1-tanh(c_{t})^2)

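Likewise, c_{t} feeds c_{t+1}, so the gradient arriving from step t+1 is added. Here \frac{\partial J}{\partial c_{next}} is the \frac{\partial J}{\partial c_{t-1}} computed while processing step t+1 (zero at the last time step):

\frac{\partial J}{\partial c_{t}} += \frac{\partial J}{\partial c_{next}}
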
\frac{\partial J}{\partial \hat{c}_t} = \frac{\partial J}{\partial c_{t}} \cdot \frac{\partial c_{t}}{\partial \hat{c}_t} \Rightarrow \frac{\partial J}{\partial \hat{c}_t} = \frac{\partial J}{\partial c_{t}} \odot i_{t}

\frac{\partial J}{\partial a_{c}} = \frac{\partial J}{\partial \hat{c}_t} \cdot \frac{\partial \hat{c}_t}{\partial a_{c}} \Rightarrow \frac{\partial J}{\partial a_{c}} = \frac{\partial J}{\partial c_{t}} \odot i_{t} \odot \frac{d(tanh(a_{c}))}{da_{c}} \newline \Rightarrow \frac{\partial J}{\partial a_{c}} = \frac{\partial J}{\partial c_{t}} \odot i_{t} \odot (1 - tanh(a_{c})^2) \newline \Rightarrow \frac{\partial J}{\partial a_{c}} = \frac{\partial J}{\partial c_{t}} \odot i_{t} \odot (1 - \hat{c}_t^2)
\frac{\partial J}{\partial W_{c}} = \frac{\partial J}{\partial a_{c}} \cdot \frac{\partial a_{c}}{\partial W_{c}} \Rightarrow \frac{\partial J}{\partial W_{c}} = \frac{\partial J}{\partial a_{c}} \cdot z_{t}^T

\frac{\partial J}{\partial b_{c}} = \frac{\partial J}{\partial a_{c}} \cdot \frac{\partial a_{c}}{\partial b_{c}} \Rightarrow \frac{\partial J}{\partial b_{c}} = \frac{\partial J}{\partial a_{c}}
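A similar sketch covers the cell state and the candidate. The probe loss J = dh^T h_{t} + dc_next^T c_{t} (again a testing device, with random dh and dc_next) makes \frac{\partial J}{\partial c_{t}} contain both the path through h_{t} and the accumulated \frac{\partial J}{\partial c_{next}} term.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(4)
n_h, n_x = 3, 2
W = {g: rng.normal(size=(n_h, n_h + n_x)) for g in "fioc"}
b = {g: rng.normal(size=(n_h, 1)) for g in "fioc"}
z = rng.normal(size=(n_h + n_x, 1))
c_prev = rng.normal(size=(n_h, 1))
dh = rng.normal(size=(n_h, 1))        # arbitrary upstream dJ/dh_t
dc_next = rng.normal(size=(n_h, 1))   # arbitrary gradient reaching c_t from step t+1


def cell(W):
    f = sigmoid(W["f"] @ z + b["f"])
    i = sigmoid(W["i"] @ z + b["i"])
    o = sigmoid(W["o"] @ z + b["o"])
    c_hat = np.tanh(W["c"] @ z + b["c"])
    c = i * c_hat + f * c_prev
    return f, i, o, c_hat, c, o * np.tanh(c)


def probe_loss(W):
    # The linear term in c_t contributes dc_next to dJ/dc_t
    f, i, o, c_hat, c, h = cell(W)
    return (dh.T @ h + dc_next.T @ c).item()


f, i, o, c_hat, c, h = cell(W)
dc = dh * o * (1 - np.tanh(c) ** 2)   # through h_t = o_t * tanh(c_t)
dc += dc_next                         # through c_{t+1}
dc_hat = dc * i                       # dJ/dc_hat_t
da_c = dc_hat * (1 - c_hat ** 2)      # dJ/da_c
dW_c = da_c @ z.T
db_c = da_c

eps = 1e-6
num_dW_c = np.zeros_like(W["c"])
for r in range(n_h):
    for col in range(n_h + n_x):
        W_plus = {g: m.copy() for g, m in W.items()}
        W_minus = {g: m.copy() for g, m in W.items()}
        W_plus["c"][r, col] += eps
        W_minus["c"][r, col] -= eps
        num_dW_c[r, col] = (probe_loss(W_plus) - probe_loss(W_minus)) / (2 * eps)
print(np.max(np.abs(num_dW_c - dW_c)))
```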
input gate
\frac{\partial J}{\partial i_{t}} = \frac{\partial J}{\partial c_{t}} \cdot \frac{\partial c_{t}}{\partial i_{t}} \Rightarrow \frac{\partial J}{\partial i_{t}} = \frac{\partial J}{\partial c_{t}} \odot \hat{c}_t 

\frac{\partial J}{\partial a_{i}} = \frac{\partial J}{\partial i_{t}} \cdot \frac{\partial i_{t}}{\partial a_{i}} \Rightarrow \frac{\partial J}{\partial a_{i}} = \frac{\partial J}{\partial c_{t}} \odot \hat{c}_t \odot \frac{d(\sigma (a_{i}))}{da_{i}} \newline \Rightarrow \frac{\partial J}{\partial a_{i}} = \frac{\partial J}{\partial c_{t}} \odot \hat{c}_t \odot \sigma (a_{i})(1- \sigma (a_{i})) \newline \Rightarrow \frac{\partial J}{\partial a_{i}} = \frac{\partial J}{\partial c_{t}} \odot \hat{c}_t \odot i_{t}(1- i_{t})

\frac{\partial J}{\partial W_{i}} = \frac{\partial J}{\partial a_{i}} \cdot \frac{\partial a_{i}}{\partial W_{i}}  \Rightarrow \frac{\partial J}{\partial W_{i}} = \frac{\partial J}{\partial a_{i}} \cdot z_{t}^T

\frac{\partial J}{\partial b_{i}} = \frac{\partial J}{\partial a_{i}} \cdot \frac{\partial a_{i}}{\partial b_{i}} \Rightarrow \frac{\partial J}{\partial b_{i}} = \frac{\partial J}{\partial a_{i}}
forget gate
\frac{\partial J}{\partial f_{t}} = \frac{\partial J}{\partial c_{t}} \cdot \frac{\partial c_{t}}{\partial f_{t}} \Rightarrow \frac{\partial J}{\partial f_{t}} = \frac{\partial J}{\partial c_{t}} \odot c_{t-1}

\frac{\partial J}{\partial a_{f}} = \frac{\partial J}{\partial f_{t}} \cdot \frac{\partial f_{t}}{\partial a_{f}} \Rightarrow \frac{\partial J}{\partial a_{f}} = \frac{\partial J}{\partial c_{t}} \odot c_{t-1} \odot \frac{d(\sigma (a_{f}))}{da_{f}} \newline \Rightarrow \frac{\partial J}{\partial a_{f}} = \frac{\partial J}{\partial c_{t}} \odot c_{t-1} \odot \sigma (a_{f})(1- \sigma (a_{f})) \newline \Rightarrow \frac{\partial J}{\partial a_{f}} = \frac{\partial J}{\partial c_{t}} \odot c_{t-1} \odot f_{t}(1- f_{t})

\frac{\partial J}{\partial W_{f}} = \frac{\partial J}{\partial a_{f}} \cdot \frac{\partial a_{f}}{\partial W_{f}}  \Rightarrow \frac{\partial J}{\partial W_{f}} = \frac{\partial J}{\partial a_{f}} \cdot z_{t}^T

\frac{\partial J}{\partial b_{f}} = \frac{\partial J}{\partial a_{f}} \cdot \frac{\partial a_{f}}{\partial b_{f}} \Rightarrow \frac{\partial J}{\partial b_{f}} = \frac{\partial J}{\partial a_{f}}
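The input and forget gate derivations share the same structure, so one sketch covers both. It reuses a hypothetical probe loss J = dh^T h_{t} + dc_next^T c_{t} with random dh and dc_next, and a small helper, numerical_grad, introduced only for this check.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(5)
n_h, n_x = 3, 2
W = {g: rng.normal(size=(n_h, n_h + n_x)) for g in "fioc"}
b = {g: rng.normal(size=(n_h, 1)) for g in "fioc"}
z = rng.normal(size=(n_h + n_x, 1))
c_prev = rng.normal(size=(n_h, 1))
dh = rng.normal(size=(n_h, 1))
dc_next = rng.normal(size=(n_h, 1))


def cell(W):
    f = sigmoid(W["f"] @ z + b["f"])
    i = sigmoid(W["i"] @ z + b["i"])
    o = sigmoid(W["o"] @ z + b["o"])
    c_hat = np.tanh(W["c"] @ z + b["c"])
    c = i * c_hat + f * c_prev
    return f, i, o, c_hat, c, o * np.tanh(c)


def probe_loss(W):
    f, i, o, c_hat, c, h = cell(W)
    return (dh.T @ h + dc_next.T @ c).item()


def numerical_grad(gate, eps=1e-6):
    # Central differences over every entry of W[gate]
    num = np.zeros_like(W[gate])
    for r in range(n_h):
        for col in range(n_h + n_x):
            W_plus = {g: m.copy() for g, m in W.items()}
            W_minus = {g: m.copy() for g, m in W.items()}
            W_plus[gate][r, col] += eps
            W_minus[gate][r, col] -= eps
            num[r, col] = (probe_loss(W_plus) - probe_loss(W_minus)) / (2 * eps)
    return num


f, i, o, c_hat, c, h = cell(W)
dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next
da_i = dc * c_hat * i * (1 - i)      # dJ/da_i
da_f = dc * c_prev * f * (1 - f)     # dJ/da_f
dW_i, db_i = da_i @ z.T, da_i
dW_f, db_f = da_f @ z.T, da_f

num_dW_i = numerical_grad("i")
num_dW_f = numerical_grad("f")
print(np.max(np.abs(num_dW_i - dW_i)), np.max(np.abs(num_dW_f - dW_f)))
```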
\frac{\partial J}{\partial z_{t}} = \frac{\partial J}{\partial a_{f}} \cdot \frac{\partial a_{f}}{\partial z_{t}} + \frac{\partial J}{\partial a_{i}} \cdot \frac{\partial a_{i}}{\partial z_{t}} + \frac{\partial J}{\partial a_{o}} \cdot \frac{\partial a_{o}}{\partial z_{t}} + \frac{\partial J}{\partial a_{c}} \cdot \frac{\partial a_{c}}{\partial z_{t}}  \newline \Rightarrow \frac{\partial J}{\partial z_{t}} =  W_{f}^T \cdot \frac{\partial J}{\partial a_{f}} +W_{i}^T \cdot \frac{\partial J}{\partial a_{i}} + W_{o}^T \cdot \frac{\partial J}{\partial a_{o}} + W_{c}^T \cdot \frac{\partial J}{\partial a_{c}}

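Since z_{t} is the concatenation [h_{t-1}, x_{t}], the gradient with respect to h_{t-1} is the first n_{h} rows of \frac{\partial J}{\partial z_{t}}, written here in Python slice notation:

\frac{\partial J}{\partial h_{t-1}} = \frac{\partial J}{\partial z_{t}}[:n_{h}, :]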

\frac{\partial J}{\partial c_{t-1}} = \frac{\partial J}{\partial c_{t}} \cdot \frac{\partial c_{t}}{\partial c_{t-1}} \Rightarrow \frac{\partial J}{\partial c_{t-1}} = \frac{\partial J}{\partial c_{t}} \odot f_{t}
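The gradients passed back to the previous time step can be verified by perturbing h_{t-1} and c_{t-1} directly. This sketch assembles \frac{\partial J}{\partial z_{t}} from all four gates, takes its first n_{h} rows, and forms \frac{\partial J}{\partial c_{t-1}} = \frac{\partial J}{\partial c_{t}} \odot f_{t}; the probe loss and the random inputs are assumptions for the check.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(6)
n_h, n_x = 3, 2
W = {g: rng.normal(size=(n_h, n_h + n_x)) for g in "fioc"}
b = {g: rng.normal(size=(n_h, 1)) for g in "fioc"}
x = rng.normal(size=(n_x, 1))
h_prev = rng.normal(size=(n_h, 1))
c_prev = rng.normal(size=(n_h, 1))
dh = rng.normal(size=(n_h, 1))        # upstream dJ/dh_t (dh_next already added)
dc_next = rng.normal(size=(n_h, 1))


def cell(h_prev, c_prev):
    z = np.vstack((h_prev, x))
    f = sigmoid(W["f"] @ z + b["f"])
    i = sigmoid(W["i"] @ z + b["i"])
    o = sigmoid(W["o"] @ z + b["o"])
    c_hat = np.tanh(W["c"] @ z + b["c"])
    c = i * c_hat + f * c_prev
    return z, f, i, o, c_hat, c, o * np.tanh(c)


def probe_loss(h_prev, c_prev):
    *_, c, h = cell(h_prev, c_prev)
    return (dh.T @ h + dc_next.T @ c).item()


z, f, i, o, c_hat, c, h = cell(h_prev, c_prev)
dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next
da = {
    "o": dh * np.tanh(c) * o * (1 - o),
    "c": dc * i * (1 - c_hat ** 2),
    "i": dc * c_hat * i * (1 - i),
    "f": dc * c_prev * f * (1 - f),
}
dz = sum(W[g].T @ da[g] for g in "fioc")   # dJ/dz_t, shape (n_h + n_x, 1)
dh_prev = dz[:n_h, :]                        # first n_h rows belong to h_{t-1}
dc_prev = dc * f                             # dJ/dc_{t-1}

eps = 1e-6
num_dh_prev = np.zeros_like(h_prev)
num_dc_prev = np.zeros_like(c_prev)
for k in range(n_h):
    step = np.zeros((n_h, 1))
    step[k] = eps
    num_dh_prev[k] = (probe_loss(h_prev + step, c_prev)
                      - probe_loss(h_prev - step, c_prev)) / (2 * eps)
    num_dc_prev[k] = (probe_loss(h_prev, c_prev + step)
                      - probe_loss(h_prev, c_prev - step)) / (2 * eps)
print(np.max(np.abs(num_dh_prev - dh_prev)), np.max(np.abs(num_dc_prev - dc_prev)))
```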

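In each training iteration, the forward propagation equations above are evaluated T times (once per time step, t = 1, ..., T), and the backpropagation equations are then evaluated in reverse order, t = T, ..., 1, so that \frac{\partial J}{\partial h_{next}} and \frac{\partial J}{\partial c_{next}} are available when each step is processed. At the end of each training iteration, the weights are updated using the cost gradient with respect to each weight, accumulated over all time steps. Assuming stochastic gradient descent with learning rate \alpha, the update equations are the following: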

\frac{\partial J}{\partial W_{f}} = \sum\limits_{t=1}^{T} \frac{\partial J}{\partial W_{f}^t}, \mspace{31mu} W_{f} -= \alpha * \frac{\partial J}{\partial W_{f}}

\frac{\partial J}{\partial W_{i}} = \sum\limits_{t=1}^{T} \frac{\partial J}{\partial W_{i}^t}, \mspace{31mu} W_{i} -= \alpha * \frac{\partial J}{\partial W_{i}}

\frac{\partial J}{\partial W_{o}} = \sum\limits_{t=1}^{T} \frac{\partial J}{\partial W_{o}^t}, \mspace{31mu} W_{o} -= \alpha * \frac{\partial J}{\partial W_{o}}

\frac{\partial J}{\partial W_{c}} = \sum\limits_{t=1}^{T} \frac{\partial J}{\partial W_{c}^t}, \mspace{31mu} W_{c} -= \alpha * \frac{\partial J}{\partial W_{c}}

\frac{\partial J}{\partial W_{v}} = \sum\limits_{t=1}^{T} \frac{\partial J}{\partial W_{v}^t}, \mspace{31mu} W_{v} -= \alpha * \frac{\partial J}{\partial W_{v}}

The biases b_{f}, b_{i}, b_{o}, b_{c} and b_{v} are accumulated and updated in the same way.
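Putting the pieces together, the sketch below runs the forward equations for t = 1, ..., T, runs the backward equations in reverse order while summing each parameter's gradient over time steps, checks one entry against finite differences, and applies a single gradient-descent step. The sizes, random data, initialisation scale, function names (forward, backward) and learning rate \alpha = 0.01 are arbitrary choices for illustration.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def softmax(v):
    e = np.exp(v - np.max(v))
    return e / np.sum(e)


GATES = "fioc"


def forward(xs, ys, p, n_h):
    """Run T forward steps; return the summed cross-entropy loss and per-step caches."""
    h, c = np.zeros((n_h, 1)), np.zeros((n_h, 1))
    loss, caches = 0.0, []
    for x, y in zip(xs, ys):
        z = np.vstack((h, x))
        f, i, o = (sigmoid(p["W_" + g] @ z + p["b_" + g]) for g in "fio")
        c_hat = np.tanh(p["W_c"] @ z + p["b_c"])
        c_prev, c = c, i * c_hat + f * c
        h = o * np.tanh(c)
        y_hat = softmax(p["W_v"] @ h + p["b_v"])
        loss -= float(np.sum(y * np.log(y_hat)))
        caches.append((z, f, i, o, c_hat, c, c_prev, h, y_hat, y))
    return loss, caches


def backward(caches, p, n_h):
    """Backpropagation through time: visit steps in reverse, summing gradients."""
    grads = {k: np.zeros_like(val) for k, val in p.items()}
    dh_next, dc_next = np.zeros((n_h, 1)), np.zeros((n_h, 1))  # zero after the last step
    for z, f, i, o, c_hat, c, c_prev, h, y_hat, y in reversed(caches):
        dv = y_hat - y
        grads["W_v"] += dv @ h.T
        grads["b_v"] += dv
        dh = p["W_v"].T @ dv + dh_next
        dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next
        da = {"o": dh * np.tanh(c) * o * (1 - o),
              "c": dc * i * (1 - c_hat ** 2),
              "i": dc * c_hat * i * (1 - i),
              "f": dc * c_prev * f * (1 - f)}
        dz = np.zeros_like(z)
        for g in GATES:
            grads["W_" + g] += da[g] @ z.T
            grads["b_" + g] += da[g]
            dz += p["W_" + g].T @ da[g]
        dh_next = dz[:n_h, :]   # becomes dJ/dh_next for step t-1
        dc_next = dc * f        # becomes dJ/dc_next for step t-1
    return grads


rng = np.random.default_rng(7)
n_x, n_h, n_y, T = 3, 4, 3, 5
p = {}
for g in GATES:
    p["W_" + g] = rng.normal(0.0, 0.5, (n_h, n_h + n_x))
    p["b_" + g] = np.zeros((n_h, 1))
p["W_v"] = rng.normal(0.0, 0.5, (n_y, n_h))
p["b_v"] = np.zeros((n_y, 1))
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
ys = [np.eye(n_y)[:, [k]] for k in rng.integers(0, n_y, size=T)]  # one-hot targets

loss, caches = forward(xs, ys, p, n_h)
grads = backward(caches, p, n_h)

# Finite-difference check on one entry of W_f
eps = 1e-6
p_plus = {k: val.copy() for k, val in p.items()}
p_minus = {k: val.copy() for k, val in p.items()}
p_plus["W_f"][0, 0] += eps
p_minus["W_f"][0, 0] -= eps
num = (forward(xs, ys, p_plus, n_h)[0] - forward(xs, ys, p_minus, n_h)[0]) / (2 * eps)

# One gradient-descent step: subtract alpha times the summed gradient
alpha = 0.01
for k in p:
    p[k] -= alpha * grads[k]
new_loss, _ = forward(xs, ys, p, n_h)
print(loss, new_loss)
```

For a small enough \alpha, the loss after the step is expected to be slightly lower than before, which is a cheap end-to-end check that the update has the right sign.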

In the next post, we will implement the above equations using NumPy and train the resulting LSTM model on real data.

6 thoughts on “Deriving the backpropagation equations for a LSTM”

    1. Since we are concatenating [h_(t-1), x_t] to get z_t, dJ/dh_(t-1) would be just the first dimension h_(t-1) elements of the dJ/dz_t vector, which is what she wrote in Python slice like notation


  1. Hey, I find your notation dJ/dh_t +- dJ/dh_next a bit confusing, it might be better to rewrite this as:
    dJ/dh_t += dJ/dh_next

