Gradient descent learns linear dynamical systems
Crossposted at offconvex.org.
From text translation to video captioning, learning to map one sequence to another is an increasingly active research area in machine learning. Fueled by the success of recurrent neural networks in their many variants, the field has seen rapid advances over the last few years. Recurrent neural networks are typically trained using some form of stochastic gradient descent combined with backpropagation for computing derivatives. The fact that gradient descent finds a useful set of parameters is by no means obvious. The training objective is typically nonconvex. Moreover, the fact that the model is allowed to maintain state is an additional obstacle that makes training recurrent neural networks challenging.
In this post, we take a step back to reflect on the mathematics of recurrent neural networks. Interpreting recurrent neural networks as dynamical systems, we will show that stochastic gradient descent successfully learns the parameters of an unknown linear dynamical system even though the training objective is nonconvex. Along the way, we’ll discuss several useful concepts from control theory, a field that has studied linear dynamical systems for decades. Investigating stochastic gradient descent for learning linear dynamical systems not only bears out interesting connections between machine learning and control theory, it might also provide a useful stepping stone for a deeper understanding of recurrent neural networks more broadly.
Linear dynamical systems
We focus on time-invariant single-input single-output (SISO) systems. For an input sequence of real numbers $x_1,\dots, x_T\in \mathbb{R}$, the system maintains a sequence of hidden states $h_1,\dots, h_T\in \mathbb{R}^n$, and produces a sequence of outputs $y_1,\dots, y_T\in \mathbb{R}$ according to the following rules:

$$h_{t+1} = Ah_t + Bx_t\,,\qquad\qquad y_t = Ch_t + Dx_t + \xi_t\,.\tag{1}$$
Here $A,B,C,D$ are linear transformations with compatible dimensions, and $\xi_t$ is Gaussian noise added to the output at each time step. In the learning problem, often called system identification in control theory, we observe samples of input-output pairs $((x_1,\dots, x_T),(y_1,\dots, y_T))$ and aim to recover the parameters of the underlying linear system.
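To make the setup concrete, here is a minimal numpy sketch of the update rules above; the function name, the example system, and all constants are our own illustration, not from the paper:

```python
import numpy as np

def simulate(A, B, C, D, x, noise_std=0.0, rng=None):
    """Run a SISO linear dynamical system on an input sequence x:

        h_{t+1} = A h_t + B x_t,    y_t = C h_t + D x_t + xi_t,

    where xi_t is Gaussian output noise with standard deviation noise_std.
    """
    rng = rng if rng is not None else np.random.default_rng(0)
    h = np.zeros(A.shape[0])          # hidden state starts at zero
    ys = np.empty(len(x))
    for t, x_t in enumerate(x):
        ys[t] = C @ h + D * x_t + noise_std * rng.standard_normal()
        h = A @ h + B * x_t
    return ys

# A tiny 2-state stable system driven by a step input.
A = np.array([[0.0, 1.0], [-0.1, 0.6]])
B = np.array([0.0, 1.0])
C = np.array([0.5, 0.5])
D = 0.1
y = simulate(A, B, C, D, np.ones(50))
```

With `noise_std=0` the outputs are the system’s deterministic response; the learning problem observes many noisy input-output pairs of exactly this kind.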
Although control theory provides a rich set of techniques for identifying and manipulating linear systems, maximum likelihood estimation with stochastic gradient descent remains a popular heuristic.
We denote by $\Theta = (A,B,C,D)$ the parameters of the true system. We parametrize our model with $\widehat{\Theta} = (\hat{A},\hat{B},\hat{C},\hat{D})$, and the trained model maintains hidden states $\hat{h}_t$ and outputs $\hat{y}_t$ exactly as in equation (1). For each given example $(x,y) = ((x_1,\dots,x_T), (y_1,\dots, y_T))$, the negative log-likelihood of model $\widehat{\Theta}$ is, up to scaling and additive constants, $\frac{1}{T}\sum_{t=1}^{T}\left(\hat{y}_t - y_t\right)^2$. The population risk is defined as the expected negative log-likelihood,

$$f(\widehat{\Theta}) = \mathbb{E}\left[\frac{1}{T}\sum_{t=1}^{T}\left(\hat{y}_t - y_t\right)^2\right].$$
Stochastic gradients of the population risk can be computed in time $O(Tn)$ via backpropagation given random samples. We can therefore directly minimize the population risk using stochastic gradient descent. The question is just whether the algorithm actually converges. Even though the state transformations are linear, the objective function we defined is not convex. Luckily, we will see that the objective is still close enough to convex for stochastic gradient descent to make steady progress towards the global minimum.
Hair dryers and quasiconvex functions
Before we go into the math, let’s illustrate the algorithm with a pressing example that we all run into every morning: hair drying. Imagine you have a hair dryer with a low temperature setting and a high temperature setting. Neither setting is ideal. So every morning you switch between the settings frantically in an attempt to steer toward the ideal temperature. Measuring the resulting temperature (red line below) as a function of the input setting (green dots below), the picture you’ll see is something like this:
You can see that the output temperature is related to the inputs. If you set the temperature to high for long enough, you’ll eventually get a high output temperature. But the system has state: briefly lowering the temperature has little effect on the outputs. Intuition suggests that these kinds of effects should be captured by a system with two or three hidden states. So, let’s see how SGD would go about finding the parameters of the system. We’ll initialize a system with three hidden states such that before training its predictions are just the inputs of the system. We then run SGD with a fixed learning rate on the same sequence for 400 steps.
The blue line shows the predictions of the model over the course of the 400 gradient updates.
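For readers who want to tinker, the experiment is easy to reenact. The sketch below substitutes synthetic data for the hair-dryer measurements and finite-difference gradients for backpropagation; all constants and names are our own choices:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3  # number of hidden states

def predict(params, x):
    """Model in controllable canonical form: learn a (last row of A), C, D."""
    a, C, D = params[:n], params[n:2 * n], params[-1]
    A = np.diag(np.ones(n - 1), k=1)  # ones on the superdiagonal
    A[-1, :] = a                      # free parameters live in the last row
    B = np.zeros(n); B[-1] = 1.0      # B is fixed in this parametrization
    h = np.zeros(n)
    ys = np.empty(len(x))
    for t, x_t in enumerate(x):
        ys[t] = C @ h + D * x_t
        h = A @ h + B * x_t
    return ys

def loss(params, x, y):
    return np.mean((predict(params, x) - y) ** 2)

def grad(params, x, y, eps=1e-5):
    """Central-difference gradient; a stand-in for backpropagation."""
    g = np.empty_like(params)
    for i in range(len(params)):
        up, dn = params.copy(), params.copy()
        up[i] += eps; dn[i] -= eps
        g[i] = (loss(up, x, y) - loss(dn, x, y)) / (2 * eps)
    return g

# Synthetic stand-in for the hair dryer: a stable 3-state target system,
# driven by frantic switching between a low and a high setting.
target = np.concatenate([[0.3, -0.1, 0.05], [0.3, 0.2, 0.1], [0.5]])
x = rng.choice([-1.0, 1.0], size=100)
y = predict(target, x) + 0.01 * rng.standard_normal(len(x))

# Initialize so that the model's predictions equal its inputs (D = 1).
params = np.zeros(2 * n + 1); params[-1] = 1.0
before = loss(params, x, y)
for _ in range(400):
    params -= 0.1 * grad(params, x, y)
after = loss(params, x, y)
```

After 400 steps the loss should sit well below its initial value; swapping `grad` for a proper backpropagation routine changes nothing conceptually, only the cost per step.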
Evidently, gradient descent converges just fine on this example. Let’s look at the hair dryer objective function along the line segment between two random points in the domain.
The function is clearly not convex, but it doesn’t look too bad either. In particular, from the picture, it could be that the objective function is quasiconvex:
Definition: For $\tau > 0$, a function $f(\theta)$ is $\tau$-quasiconvex with respect to a global minimum $\theta^*$ if for every $\theta$,

$$\langle \nabla f(\theta),\, \theta - \theta^*\rangle \ \ge\ \tau\left(f(\theta) - f(\theta^*)\right).$$
Intuitively, quasiconvexity states that the descent direction $-\nabla f(\theta)$ is positively correlated with the ideal direction of movement $\theta^* - \theta$. This implies that the potential function $\left\|\theta - \theta^*\right\|^2$ decreases in expectation at each step of stochastic gradient descent. This observation plugs nicely into the standard SGD analysis, leading to the following result:
Proposition (informal): Suppose the population risk $f(\theta)$ is $\tau$-quasiconvex. Then stochastic gradient descent (with fresh samples at each iteration and a proper learning rate) converges to a point $\theta_K$ in $K$ iterations with error bounded by $f(\theta_K) - f(\theta^*) \leq O\left(1/(\tau \sqrt{K})\right)$.
The key challenge for us is to understand under what conditions we can prove that the population risk objective is in fact quasiconvex. This requires some background.
Control theory, polynomial roots, and Pac-Man
A linear dynamical system $(A,B,C,D)$ is equivalent to the system $(TAT^{-1}, TB, CT^{-1}, D)$ for any invertible matrix $T$ in terms of the behavior of the outputs. A little thought therefore shows that in its unrestricted parameterization the objective function cannot have a unique optimum. A common way of removing this redundancy is to impose a canonical form. Almost all non-degenerate systems admit the controllable canonical form, defined as

$$A = \begin{bmatrix} 0 & 1 & 0 & \cdots & 0\\ 0 & 0 & 1 & \cdots & 0\\ \vdots & \vdots & \vdots & \ddots & \vdots\\ 0 & 0 & 0 & \cdots & 1\\ -a_n & -a_{n-1} & -a_{n-2} & \cdots & -a_1 \end{bmatrix}, \qquad B = \begin{bmatrix} 0\\ 0\\ \vdots\\ 0\\ 1 \end{bmatrix}.$$
We will also parametrize our training model using these forms. One of its nice properties is that the coefficients of the characteristic polynomial of the state transition matrix $A$ can be read off (up to sign) from the last row of $A$. That is,

$$\det(zI - A) \ =\ p_a(z) \ =\ z^n + a_1 z^{n-1} + \dots + a_n\,.$$
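This correspondence is easy to verify numerically. The helper below (our own code) uses the sign convention in which the characteristic polynomial is $z^n + a_1 z^{n-1} + \cdots + a_n$, so the last row of $A$ holds the negated, reversed coefficients:

```python
import numpy as np

def companion(a):
    """Transition matrix in controllable canonical form whose characteristic
    polynomial is z^n + a[0] z^(n-1) + ... + a[n-1]."""
    n = len(a)
    A = np.diag(np.ones(n - 1), k=1)              # ones on the superdiagonal
    A[-1, :] = -np.asarray(a, dtype=float)[::-1]  # last row: negated, reversed
    return A

a = [0.5, -0.25, 0.125]
A = companion(a)
# np.poly returns the characteristic polynomial, leading coefficient first.
print(np.poly(A))  # -> approximately [1.  0.5  -0.25  0.125]
```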
Even in controllable canonical form, it still seems rather difficult to learn arbitrary linear dynamical systems. A natural restriction is stability, that is, to require that the eigenvalues of $A$ are all bounded by $1$ in absolute value. Equivalently, the roots of the characteristic polynomial should all be contained in the complex unit disc. Without stability, the state of the system could blow up exponentially, making robust learning difficult. But the set of all stable systems forms a nonconvex domain. It seems daunting to guarantee that stochastic gradient descent would converge from an arbitrary starting point in this domain without ever leaving the domain.
We will therefore impose a stronger restriction on the roots of the characteristic polynomial. We call this the Pac-Man condition. You can think of it as a strengthening of stability.
Pac-Man condition: A linear dynamical system in controllable canonical form satisfies the Pac-Man condition if the coefficient vector $a$ defining the state transition matrix satisfies

$$\mathrm{Re}\left(q_a(z)\right) \ >\ \left|\mathrm{Im}\left(q_a(z)\right)\right|$$

for all complex numbers $z$ of modulus $|z| = 1$, where $q_a(z) = p_a(z)/z^n = 1+a_1z^{-1}+\dots + a_nz^{-n}$.
Above, we illustrate this condition for a degree-4 system by plotting the value of $q_a(z)$ in the complex plane for all complex numbers $z$ on the unit circle.
We note that the Pac-Man condition is satisfied by any vector $a$ with $\|a\|_1 \le \sqrt{2}/2$. Moreover, if $a$ is a random Gaussian vector with expected $\ell_2$ norm bounded by $o(1/\sqrt{\log n})$, then it satisfies the Pac-Man condition with probability $1-o(1)$. Roughly speaking, the assumption requires that the roots of the characteristic polynomial $p_a(z)$ be relatively dispersed inside the unit circle.
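Interpreting the condition as requiring $q_a(z)$ to stay inside the 90-degree wedge $\mathrm{Re}(q_a(z)) > |\mathrm{Im}(q_a(z))|$ on the unit circle (the wedge that reappears in the proof at the end of this post), a numerical check might look like this; the code is our own sketch, sampling the circle on a finite grid:

```python
import numpy as np

def satisfies_pacman(a, num_points=2000):
    """Check Re(q_a(z)) > |Im(q_a(z))| for z sampled on the unit circle,

    where q_a(z) = 1 + a_1 z^{-1} + ... + a_n z^{-n}.
    """
    a = np.asarray(a, dtype=float)
    lam = np.linspace(0.0, 2 * np.pi, num_points, endpoint=False)
    z = np.exp(1j * lam)
    q = 1 + sum(a_k * z ** -(k + 1) for k, a_k in enumerate(a))
    return bool(np.all(q.real > np.abs(q.imag)))

print(satisfies_pacman([0.3, -0.2, 0.1]))  # small l1 norm -> True
print(satisfies_pacman([2.0, 0.0, 0.0]))   # large coefficient -> False
```

The first example has $\|a\|_1 = 0.6 \le \sqrt{2}/2$, so $q_a$ stays within distance $0.6$ of the point $1$, safely inside the wedge; the second fails already at $z = -1$, where $q_a(-1) = -1$.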
The Pac-Man condition has three important implications:

1. It implies via Rouché’s theorem that the spectral radius of $A$ is smaller than $1$, and therefore ensures stability of the system.

2. The vectors satisfying it form a convex set in $\mathbb{R}^n$.

3. Finally, it ensures that the objective function is quasiconvex.
Main result
Relying on the Pac-Man condition, we can show:
Main theorem (Hardt, Ma, Recht, 2016): Under the Pac-Man condition, the projected gradient descent algorithm, given $N$ sample sequences of length $T$, returns parameters $\widehat{\Theta}$ with population risk

$$f(\widehat{\Theta}) \ \le\ \sigma^2 \ +\ O\!\left(\sqrt{\frac{\mathrm{poly}(n)}{TN}}\right),$$

where $\sigma^2$ is the variance of the output noise $\xi_t$.
The theorem sorts out the right dependence on $N$ and $T$. Even if there is only one sequence, we can learn the system provided that the sequence is long enough. Similarly, even if sequences are really short, we can learn provided that there are enough sequences.
Quasiconvexity in the frequency domain
To establish quasiconvexity under the Pac-Man condition, we will first develop an explicit formula for the population risk in the frequency domain. In doing so, we assume that $x_1,\dots, x_T$ are pairwise independent with mean 0 and variance 1. For simplicity, in this post we also consider the population risk in the limit $T\rightarrow \infty$.
A simple algebraic manipulation simplifies the population risk with infinite sequence length to

$$\lim_{T\to\infty} f(\widehat{\Theta}) \ =\ (\hat{D} - D)^2 \ +\ \sum_{k=0}^{\infty}\left(\hat{C}\hat{A}^k B - C A^k B\right)^2 \ +\ \sigma^2\,.$$
The first term, $(\hat{D} - D)^2$, is convex and appears nowhere else, and the noise contributes only an additive constant. We can safely ignore both and focus on the remaining expression instead, which we call the idealized risk:

$$\sum_{k=0}^{\infty}\left(\hat{C}\hat{A}^k B - C A^k B\right)^2\,.$$
To deal with the sequence $\hat{C}\hat{A}^kB$, we take its Fourier transform and obtain

$$\widehat{G}_{\lambda} \ =\ \sum_{k=0}^{\infty} \hat{C}\hat{A}^k B\, e^{-i(k+1)\lambda} \ =\ \hat{C}\left(e^{i\lambda} I - \hat{A}\right)^{-1} B\,.$$
Similarly, we take the Fourier transform of the sequence $CA^kB$, denoted by $G_{\lambda}$. Then by Parseval’s theorem, we obtain the following alternative representation of the idealized risk:

$$\sum_{k=0}^{\infty}\left(\hat{C}\hat{A}^k B - C A^k B\right)^2 \ =\ \frac{1}{2\pi}\int_0^{2\pi} \left|G_{\lambda} - \widehat{G}_{\lambda}\right|^2 \, d\lambda\,.$$
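Parseval’s theorem here is easy to sanity-check numerically: the truncated time-domain series and a Riemann sum of the frequency-domain integral should agree. The sketch below (our own code and conventions, with randomly generated stable systems sharing the canonical $B$) does exactly that:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3

def random_stable_system():
    """A random n-state system with spectral radius safely below 1."""
    A = rng.standard_normal((n, n))
    A /= 1.5 * max(1.0, np.max(np.abs(np.linalg.eigvals(A))))
    C = rng.standard_normal(n)
    return A, C

B = np.zeros(n); B[-1] = 1.0       # shared B, as in canonical form
A, C = random_stable_system()
Ah, Ch = random_stable_system()

# Time domain: truncated series sum_k (C A^k B - Ch Ah^k B)^2.
time_risk, Ak, Ahk = 0.0, np.eye(n), np.eye(n)
for _ in range(500):
    time_risk += (C @ Ak @ B - Ch @ Ahk @ B) ** 2
    Ak, Ahk = Ak @ A, Ahk @ Ah

# Frequency domain: (1/2pi) * integral of |G - Gh|^2 over [0, 2pi],
# with G = C (e^{i lam} I - A)^{-1} B.
lams = np.linspace(0.0, 2 * np.pi, 4096, endpoint=False)
freq_risk = 0.0
for lam in lams:
    z = np.exp(1j * lam)
    G = C @ np.linalg.solve(z * np.eye(n) - A, B)
    Gh = Ch @ np.linalg.solve(z * np.eye(n) - Ah, B)
    freq_risk += abs(G - Gh) ** 2
freq_risk /= len(lams)

print(time_risk, freq_risk)  # the two agree to high precision
```

Because the systems are stable, both the truncation error of the series and the discretization error of the periodic integral decay geometrically, so the two numbers match to many digits.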
Mapping out $G_\lambda$ and $\widehat G_\lambda$ for all $\lambda\in [0, 2\pi]$ gives the following picture:
Left: target transfer function $G$. Right: the approximation $\widehat G$ over the course of 10 gradient steps.
Given this pretty representation of the idealized risk objective, we can finally prove our main lemma.
Lemma: Suppose $\Theta$ satisfies the Pac-Man condition. Then, for every $0\le \lambda\le 2\pi$, the function $\left|G_{\lambda} - \widehat{G}_{\lambda}\right|^2$, viewed as a function of $\hat{A}$ and $\hat{C}$, is quasiconvex in the Pac-Man region.
The lemma reduces to the following simple claim.
Claim: The function $h(\hat{u},\hat{v}) = \left|\hat{u}/\hat{v} - u/v\right|^2$ is quasiconvex in the region where $\mathrm{Re}(\hat{v}/v) > 0$.
The proof simply involves computing the gradients and checking the conditions for quasiconvexity by elementary algebra. We omit a formal proof, but instead show a plot of the function $h(\hat{u}, \hat{v}) = (\hat{u}/\hat{v} - 1)^2$ over the reals:
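For this real-valued instance (with $u = v = 1$ and the global minimum $\theta^* = (1,1)$), we can also spot-check the quasiconvexity inequality $\langle \nabla h(\theta),\, \theta - \theta^*\rangle \ge 0$ numerically over the region $\hat v > 0$; the finite-difference check below is our own sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

def h(u_hat, v_hat):
    """The real-valued instance h = (u_hat/v_hat - 1)^2, i.e. u = v = 1."""
    return (u_hat / v_hat - 1.0) ** 2

def grad_h(u_hat, v_hat, eps=1e-6):
    """Central-difference gradient of h."""
    du = (h(u_hat + eps, v_hat) - h(u_hat - eps, v_hat)) / (2 * eps)
    dv = (h(u_hat, v_hat + eps) - h(u_hat, v_hat - eps)) / (2 * eps)
    return np.array([du, dv])

theta_star = np.array([1.0, 1.0])  # one global minimum (any u_hat = v_hat works)

# <grad h(theta), theta - theta*> should be nonnegative whenever v_hat > 0.
ok = True
for _ in range(10000):
    theta = np.array([rng.uniform(-3, 3), rng.uniform(0.2, 3.0)])  # v_hat > 0
    inner = grad_h(*theta) @ (theta - theta_star)
    ok = ok and inner >= -1e-4   # small tolerance for finite differences
print(ok)
```

A short calculation confirms what the sampling suggests: with $r = \hat u/\hat v - 1$, the inner product equals $2r^2/\hat v$, which is nonnegative exactly when $\hat v > 0$.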
To see how the lemma follows from the previous claim, we note that quasiconvexity is preserved under composition with any linear transformation. Specifically, if $h(z)$ is quasiconvex, then $h(Rx)$ is also quasiconvex for any linear map $R$. So, consider the linear map

$$R:\ (\hat{a}, \hat{C}) \ \mapsto\ (\hat{u}, \hat{v}) \ =\ \left(\hat{c}_1 + \hat{c}_2 z + \dots + \hat{c}_n z^{n-1},\ \ z^n + \hat{a}_1 z^{n-1} + \dots + \hat{a}_n\right)\Big|_{z = e^{i\lambda}}\,,$$

under which $\widehat{G}_{\lambda} = \hat{u}/\hat{v}$ in controllable canonical form.
With this linear transformation, our simple claim about a bivariate function extends to show that $\left|G_{\lambda} - \widehat{G}_{\lambda}\right|^2$ is quasiconvex whenever $\mathrm{Re}(\hat{v}/v) > 0$. In particular, when $\hat{a}$ and $a$ both satisfy the Pac-Man condition, then $\hat{v}$ and $v$ both reside in the 90-degree wedge around the positive real axis (up to the common phase factor $z^n$). The angle between them is therefore smaller than 90 degrees, which implies $\mathrm{Re}(\hat{v}/v) > 0$.
Conclusion
We saw conditions under which stochastic gradient descent successfully learns a linear dynamical system. In our paper, we further show that allowing our learned system to have more parameters than the target system makes the problem dramatically easier. In particular, at the expense of slight overparameterization we can weaken the Pac-Man condition to a mild separation condition on the roots of the characteristic polynomial. This is consistent with empirical observations both in machine learning and control theory that highlight the effectiveness of additional model parameters.
More broadly, we hope that our techniques will be a first stepping stone toward a better theoretical understanding of recurrent neural networks.