A Paper in Thousand Words: Neural Architecture Search with Reinforcement Learning

The paper can be found here.

One of the key advantages of deep models is that they made manual feature engineering largely obsolete. With this came a paradigm shift: from engineering robust features to engineering deep architectures, i.e. hyperparameters, for machine learning tasks.

This paper uses reinforcement learning (RL) to find the best deep architecture for a given task. This is achieved by using a “controller”, a Long Short-Term Memory (LSTM) network (a variant of the Recurrent Neural Network), to output architectures, e.g. a child network that is a Convolutional Neural Network (CNN). The child network is then trained on a dataset (e.g. CIFAR10), and its accuracy on a validation set is used as the reward signal in the RL environment. Finally, the controller updates its weights in a direction that maximizes the expected reward. This is achieved using an algorithm known as “REINFORCE”.

This problem of “Hyperparameter Optimization” is not a new concept. It is quite old and has previously been attempted using techniques such as Genetic Algorithms, Bayesian Optimization, etc. This work also shares a close relationship with “Learning to Learn” and “Meta-Learning”.

Let us now sail the high seas and dive deep into the algorithm!

The Controller

The controller is an LSTM. An LSTM is a flexible learning model that performs well on time-series data (Shallow Explanation). The paper doesn’t define what the initial input is, so I assume it’s a constant token such as “ROOT” that is always the same. Now assume the controller predicts N layers. For each layer, the LSTM predicts the number of filters, filter height, filter width, stride height and stride width in sequence, where each of these is a single prediction. Each prediction is made by passing the input through the network (a single column in the figure) and obtaining output probabilities from a softmax layer (colored cells in the figure). The prediction at time step t is the input at time step t+1. The controller is illustrated in the figure below. The number in the upper right corner indicates the number of output units in that layer.

Figure: The Controller
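The sampling loop described above can be sketched in Python. This is a minimal sketch under stated assumptions: the LSTM step is replaced by a hypothetical `logits_fn` (by default it returns uniform logits, whereas a real controller would update its hidden state with the previous prediction), the value lists in `SEARCH_SPACE` are illustrative, and the constant first input “ROOT” is my own assumption, as mentioned above.

```python
import math
import random

# Illustrative search space: five decisions per layer, in the order described above.
# The value lists are assumptions for this sketch.
SEARCH_SPACE = [
    ("num_filters", [24, 36, 48, 64]),
    ("filter_height", [1, 3, 5, 7]),
    ("filter_width", [1, 3, 5, 7]),
    ("stride_height", [1, 2, 3]),
    ("stride_width", [1, 2, 3]),
]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def uniform_logits(history, num_choices):
    """Hypothetical stand-in for one LSTM step. A real controller would feed
    history[-1] (the previous prediction) into the LSTM, update its hidden state
    and project it to `num_choices` logits; this placeholder ignores the history."""
    return [0.0] * num_choices


def sample_tokens(num_layers, logits_fn=uniform_logits, seed=0):
    rng = random.Random(seed)
    history = ["ROOT"]  # assumed constant first input
    for _ in range(num_layers):
        for _name, choices in SEARCH_SPACE:
            probs = softmax(logits_fn(history, len(choices)))
            idx = rng.choices(range(len(choices)), weights=probs)[0]
            history.append(choices[idx])  # prediction at step t is the input at step t+1
    return history[1:]  # drop "ROOT"


tokens = sample_tokens(num_layers=2)
print(tokens)  # 10 tokens: 5 per layer
```

Passing a different `logits_fn` is where a trained controller would plug in; the loop itself only needs logits for the current decision.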

The concatenated output of the controller is a sequence of tokens that specifies the hyperparameters of each layer. A network with the given specification is then built and trained on the task, e.g. classification on CIFAR10. Next, the performance (accuracy) of the generated network is evaluated on a validation set. This accuracy is used as the reward signal R to update the controller’s parameters \theta_c.
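To make “a sequence of tokens that specifies the hyperparameters of each layer” concrete, here is a small sketch that groups a flat token sequence into per-layer specifications, assuming five tokens per layer in the order listed in the controller section. `FIELDS`, `decode` and the commented-out `train_and_evaluate` are hypothetical names for illustration, not from the paper.

```python
FIELDS = ["num_filters", "filter_height", "filter_width", "stride_height", "stride_width"]


def decode(tokens):
    """Group the controller's flat token sequence into one spec dict per layer."""
    k = len(FIELDS)
    if len(tokens) % k != 0:
        raise ValueError(f"expected a multiple of {k} tokens, got {len(tokens)}")
    return [dict(zip(FIELDS, tokens[i:i + k])) for i in range(0, len(tokens), k)]


layers = decode([36, 3, 3, 1, 1, 64, 5, 5, 2, 2])
for n, spec in enumerate(layers, start=1):
    print(f"layer {n}: {spec}")

# Next (not implemented here): build and train the child CNN from `layers`,
# then use its validation accuracy as the reward, e.g.
#   R = train_and_evaluate(layers)  # hypothetical helper
```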

However, there’s a pitfall here. R is not differentiable with respect to \theta_c: it is not an output of the controller, but the validation accuracy of a child network that was trained separately. To get around this, a policy gradient technique known as REINFORCE is used. The goal of REINFORCE (Ref) is to provide a mechanism, i.e. an equation, that allows us to update \theta_c in a direction that maximizes the expected value of a non-differentiable reward. By using REINFORCE together with an empirical (sample-based) approximation, we arrive at the following gradient estimate, which is used to update \theta_c.

\begin{equation*} \nabla_{\theta_c} E[R] \approx \frac{1}{m}\sum_{k=1}^{m}\sum_{t=1}^{T}\nabla_{\theta_c}\log P(a_t \mid a_{(t-1):1};\theta_c)\,R_k \tag{1} \end{equation*}

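This basically says that the gradient for the controller is approximated by “the sum of the gradients of the log-probabilities of the predictions, weighted by the reward of each network sampled by the controller, averaged over all the networks sampled in a single batch”. Here m is the number of networks sampled in a batch, T is the number of tokens the controller predicts for each network, and R_k is the validation accuracy of the k-th network. I know it’s a mouthful, so process it slowly! That’s it for the basic algorithm.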
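Here is a toy, runnable version of eq. (1). To keep it self-contained, I swap in two simplifications that are not from the paper: the LSTM controller is replaced by independent softmax distributions over T = 2 decisions (so P(a_t | a_{(t-1):1}) reduces to P(a_t)), and training a child network is replaced by a made-up reward function, `toy_reward`. No baseline is subtracted from R_k, matching eq. (1) as written.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_reward(actions):
    """Hypothetical stand-in for the child network's validation accuracy.
    It pretends choice 1 is best for decision 0 and choice 2 for decision 1."""
    return 1.0 - 0.2 * abs(actions[0] - 1) - 0.2 * abs(actions[1] - 2)


def reinforce_gradient(theta, batch):
    """Eq. (1): (1/m) * sum_k sum_t grad log P(a_t) * R_k.
    For softmax logits z, the gradient of log p[a] w.r.t. z is onehot(a) - p."""
    m = len(batch)
    grad = [[0.0] * len(z) for z in theta]
    for actions, reward in batch:
        for t, z in enumerate(theta):
            p = softmax(z)
            for j in range(len(z)):
                indicator = 1.0 if j == actions[t] else 0.0
                grad[t][j] += (indicator - p[j]) * reward / m
    return grad


def search(iterations=1000, m=16, lr=0.1, seed=0):
    rng = random.Random(seed)
    theta = [[0.0] * 4, [0.0] * 4]  # T = 2 decisions with 4 choices each
    for _ in range(iterations):
        batch = []
        for _ in range(m):  # sample m "child networks"
            actions = [rng.choices(range(len(z)), weights=softmax(z))[0] for z in theta]
            batch.append((actions, toy_reward(actions)))
        grad = reinforce_gradient(theta, batch)
        # Gradient ascent: move theta_c in the direction that increases E[R].
        theta = [[z + lr * g for z, g in zip(zs, gs)] for zs, gs in zip(theta, grad)]
    return theta


def expected_reward(theta):
    """Exact E[R] under the controller, by enumerating all 4 x 4 architectures."""
    p0, p1 = softmax(theta[0]), softmax(theta[1])
    return sum(p0[i] * p1[j] * toy_reward([i, j])
               for i in range(len(p0)) for j in range(len(p1)))


theta = search()
print(expected_reward(theta))  # the uniform starting policy scores 0.6
```

Because the reward only enters as a multiplier on the log-probability gradient, nothing in `reinforce_gradient` ever differentiates `toy_reward`. That is exactly why REINFORCE works with a non-differentiable R.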

Faster Training

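It’s no secret that they have to evaluate thousands of models, each with millions of parameters, which is computationally very costly. To speed this up, they use a distributed “parameter-server” scheme (Ref): the controller’s parameters are kept on a parameter server and shared by several controller replicas, each of which samples a batch of child networks, trains them in parallel and sends the resulting gradients back to the server.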
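Below is a minimal single-process simulation of the parameter-server idea, assuming the scheme works as I understand it: several controller replicas pull a copy of \theta_c from the server, each computes a gradient from its own batch of child networks, and pushes it back. `ParameterServer`, `replica_gradient` and the quadratic toy objective inside it are hypothetical placeholders. A real system runs the replicas in parallel on separate machines and computes the REINFORCE gradient from trained child networks.

```python
import random


class ParameterServer:
    """Toy, single-process stand-in for the server holding the controller weights theta_c."""

    def __init__(self, theta, lr):
        self.theta = list(theta)
        self.lr = lr
        self.num_updates = 0

    def pull(self):
        return list(self.theta)  # replicas work on a copy, which can go stale

    def push(self, grad):
        # Apply one replica's gradient (ascent, since we maximize the reward).
        self.theta = [w + self.lr * g for w, g in zip(self.theta, grad)]
        self.num_updates += 1


def replica_gradient(theta, rng):
    """Hypothetical placeholder for one controller replica. In the real system it would
    sample m child networks, train them in parallel and return the REINFORCE gradient;
    here it returns a noisy gradient of the toy objective -(w - 1)^2 for each weight."""
    return [-2.0 * (w - 1.0) + rng.gauss(0.0, 0.1) for w in theta]


def run(num_rounds=200, num_replicas=4, seed=0):
    rng = random.Random(seed)
    server = ParameterServer([0.0, 0.0], lr=0.05)
    for _ in range(num_rounds):
        # Every replica pulls the same snapshot, then they push one after another,
        # so later pushes apply gradients computed on slightly stale weights.
        snapshots = [server.pull() for _ in range(num_replicas)]
        for theta in snapshots:
            server.push(replica_gradient(theta, rng))
    return server


server = run()
print(server.num_updates, server.theta)  # 800 updates; weights end up close to 1.0
```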

Advanced Hyperparameters

Here they also introduce ways to explore more advanced hyperparameters such as “skip connections” or “branching layers”. Skip connections play a vital role in very deep networks, as they help the gradient flow (Ref). By branching layers I assume they mean something similar to Google’s Inception module.

To implement this, they introduce “Anchor Points” between the hyperparameter sets of consecutive layers. The anchor point at layer n has n-1 incoming connections, one from each previous layer, which decide which of the previous layers will be inputs to the n^{th} layer. For each of these connections, the controller outputs a sigmoid probability computed from learned weights and hidden states (see the original paper for details).
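A sketch of one anchor point’s n-1 sigmoid outputs. I assume the additive form sigmoid(v^T tanh(W_prev h_j + W_curr h_n)), where h_j and h_n are the controller’s hidden states at the anchor points of an earlier layer j and the current layer n. The names `v`, `W_prev` and `W_curr` are illustrative labels for the learned weights, and the hidden states and weights below are made up, so check the paper for its exact notation.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def skip_probabilities(anchor_states, v, W_prev, W_curr):
    """anchor_states[j] is the controller's hidden state at layer j's anchor point;
    the last entry belongs to the current layer n. Returns one probability per
    earlier layer j: sigmoid(v^T tanh(W_prev h_j + W_curr h_n))."""
    h_n = anchor_states[-1]
    curr = matvec(W_curr, h_n)
    probs = []
    for h_j in anchor_states[:-1]:  # the n - 1 incoming connections
        prev = matvec(W_prev, h_j)
        score = sum(vi * math.tanh(p + c) for vi, p, c in zip(v, prev, curr))
        probs.append(sigmoid(score))
    return probs


# Made-up hidden states (size 2) for layers 1..3 and made-up weights.
states = [[0.1, -0.3], [0.5, 0.2], [-0.4, 0.7]]
probs = skip_probabilities(states, v=[1.0, -0.5],
                           W_prev=[[0.3, 0.1], [-0.2, 0.4]],
                           W_curr=[[0.2, 0.0], [0.1, 0.5]])
# Sample which earlier layers feed layer 3, one independent coin flip per connection.
rng = random.Random(0)
inputs_to_layer_3 = [j + 1 for j, p in enumerate(probs) if rng.random() < p]
print(probs, inputs_to_layer_3)
```

Using an independent sigmoid per connection (rather than one softmax over all earlier layers) is what lets a layer take zero, one or several skip inputs.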

Extending the controller to design LSTMs

Things are about to get hairier, so bear with me. I must also warn you that in the version of the paper I have at hand, the notation in this part is all over the place and difficult to follow (probably because it was still under review). So I’m going to build my own interpretation, following my LSTM post.

Figure: Designing an LSTM

It’s very hard to give a comprehensive explanation of this, due to the poorly formulated explanation in the paper and the complexity of LSTMs themselves, but I will explain it to the extent I understood it. Designing an LSTM is not as straightforward as designing a CNN, because an LSTM has numerous time-dependent and time-independent calculations happening. However, an LSTM cell can be represented as a tree structure that takes three inputs, h_{t-1} (previous state), x_t (input) and c_{t-1} (previous memory), and applies various transformations to output h_t and c_t, as illustrated in the upper-left corner of the second figure.

Now let us see how the controller designs LSTMs. First, for each tree index, the controller makes 2 predictions: a way to combine h_{t-1} and x_t (e.g. addition or element-wise multiplication) and an activation function. This is quite sensible, but the rest is unclear. Next, there are two outputs called “cell inject”, which are used by the last output to calculate what the authors call a_0^{new}, which in turn is used to calculate a_2. I don’t have a good explanation for how or why this part works (dashed red line).
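To make the tree picture more concrete, here is a minimal sketch of evaluating one designed cell with three tree nodes, based on my reading of the description above. Every concrete choice is an assumption for illustration: which combination and activation each node uses, treating W_1..W_4 as learned matrices that project x_t and h_{t-1} at the leaves, merging c_{t-1} into a_0 through the “cell inject” step without weights, and setting c_t to node 1’s output before its activation (since the figure omits c_t). It only shows how such a cell would be computed under these assumptions, not why it works.

```python
import math


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


COMBINE = {
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],  # element-wise
}
ACTIVATION = {
    "tanh": math.tanh,
    "relu": lambda x: max(0.0, x),
    "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
}


def node(choice, a, b):
    """Evaluate one tree node from the controller's (combine, activation) choice.
    Returns (pre-activation, post-activation)."""
    combine, act = choice
    pre = COMBINE[combine](a, b)
    return pre, [ACTIVATION[act](x) for x in pre]


def tree_cell(x_t, h_prev, c_prev, W, design):
    # Leaves: mix projected x_t and h_{t-1}; W[0]..W[3] play the role of W_1..W_4.
    _, a0 = node(design["node0"], matvec(W[0], x_t), matvec(W[1], h_prev))
    pre1, a1 = node(design["node1"], matvec(W[2], x_t), matvec(W[3], h_prev))
    # "Cell inject": merge the previous memory into node 0, giving a_0^new (no weights).
    _, a0_new = node(design["inject"], a0, c_prev)
    # Root node 2 combines a_0^new and a_1; its output is the new hidden state h_t.
    _, h_t = node(design["node2"], a0_new, a1)
    # Assumed choice for the missing c_t: node 1's output before its activation.
    c_t = pre1
    return h_t, c_t


# One illustrative design the controller might emit (hypothetical, not a search result).
design = {"node0": ("add", "tanh"), "node1": ("mul", "relu"),
          "inject": ("add", "relu"), "node2": ("mul", "sigmoid")}
W = [[[0.5, -0.2], [0.1, 0.3]] for _ in range(4)]  # arbitrary 2x2 weights
h_t, c_t = tree_cell([1.0, 0.0], [0.0, 1.0], [0.2, -0.1], W, design)
print(h_t, c_t)
```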

In the experiments, they claim that their CIFAR10 model outperforms the previous state-of-the-art model with a similar architecture and is 1.05x faster. But I wouldn’t call these ground-breaking results. Finally, the LSTM cell designed by their controller is said to outperform the previous best model on the Penn Treebank dataset.

Conclusion (My opinion)

I find their approach to the problem of hyperparameter optimization interesting. But some parts of the explanation of the algorithm are poorly structured and lack certain details (e.g. what is the input to the controller? The latter part of the controller in Figure 5 is very confusing; Figure 5 (left) is missing c_t; and what are \nu^T, W_1, W_2, ...?). The experiments could also be improved: it would be more convincing to see experiments on more datasets (e.g. CIFAR100, ImageNet). And the very claim they market in the abstract, “Our CIFAR-10 model … which is 0.09 percent better and 1.05x faster”, is undercut in the experiments, where they say “DenseNet model … uses 1×1 convolutions to reduce its parameters, which we did not do, so it is not an exact comparison”. So I have my doubts about this method.

PS: Special thanks to Pepe for directing me to this paper.