## RA-DAE: Structurally Adaptive Deep Architecture inspired by Reinforcement Learning

In this post, I’m going to introduce a type of Stacked Autoencoder (SAE). (Don’t worry if you don’t know what an SAE is; I’ll explain later.) It’s worth mentioning that this is research done by me and a few colleagues from our research lab, so yay for shameless self-advertising. Anyway, I thought this was worth sharing because, to the best of my knowledge, it is the first of its kind to use Reinforcement Learning to adapt the structure of the network. This paper has been published in ECAI 2016 and can be downloaded via this link.

## Introduction

I don’t think I need to stress that deep learning is at the heart of machine learning and continues to mesmerise the world with its unprecedented capabilities. However, no matter how formidable deep learning sounds, it is not without weaknesses. For example, ponder what would happen to the knowledge stored in a deep network if the class distribution of the data kept changing. Let me illustrate this with an example.

## Covariate-Shift: Garlic for Deep Learning

Though everyone seems to think of Deep Learning as “the silver bullet”, it can fail miserably under adverse conditions. For example, vanilla deep learning methods currently do not handle “covariate shift” well. So what exactly is covariate shift? Put simply, it is when the data distribution changes over time, so that you don’t get the same number of instances of each class uniformly over time.

## Ill-Effects of the Covariate Shift

Okay, fair enough! But what could be so bad about covariate shift that it could tarnish the reputation of deep learning? Let’s understand that through an example scenario.

Think of an example where you’re trying to classify images belonging to two classes: digits 1 and 2. If your data is uniformly distributed among the classes, no biggie! Your neural net will do fine. But if I feed all the digit 1 images first and then the digit 2 images, your neural net will not do as well as it would with the uniform distribution. This is because during the first half the network will overfit to digit 1 and then overfit to digit 2, and thus fail to learn a balanced representation of both digits. This is the ill-effect of covariate shift, and this is what we’re trying to address in our paper.

## Solution: Let the Neural Network change its Structure

So let’s start weaving this intricate tapestry called RA-DAE, which copes with covariate shift. RA-DAE is a neural network that functions like a Stacked Denoising Autoencoder: you train the model and get the predictions out. But here’s the catch: RA-DAE can add or remove neurons from the layers of the network on the go.

So in the little example we talked about above, RA-DAE would ideally keep the network constant while learning the digit 1 images, and then start adding neurons to accommodate digit 2 as soon as the network starts seeing those images.

How cool is that? Okay, it’s not so cool. Actually, this paper does exactly the same thing. Aaaah! But that paper uses the immediate error of the neural network to decide whether the network needs more neurons, and it adds or removes a predefined number of neurons according to a few simple rules based on short-term accuracy gain, which doesn’t sound that cool.

This is where RA-DAE shines! RA-DAE uses reinforcement learning to continuously update a value function that says how good an action (adding or removing neurons) is, given the state the network is currently in. So we are no longer looking at the immediate error (or accuracy), but at the long-term error (or accuracy).
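To make “long-term” concrete, here is a tiny Python sketch of a discounted return, the quantity a Q-value estimates in standard RL. The reward sequences and the discount factor `gamma=0.9` are made up for illustration; they are not values from RA-DAE.

```python
def discounted_return(rewards, gamma=0.9):
    """Sum of future rewards, each discounted by gamma per step of delay."""
    g = 0.0
    for k, r in enumerate(rewards):
        g += (gamma ** k) * r
    return g


# Hypothetical rewards observed after two different actions
greedy = [1.0, 0.0, 0.0]   # best immediate reward, nothing afterwards
patient = [0.0, 1.0, 1.0]  # no immediate gain, but better later on

print(discounted_return(greedy))   # 1.0
print(discounted_return(patient))  # about 1.71 (0.9 + 0.81)
```

An agent judging only by the first reward would prefer `greedy`; one judging by the discounted return prefers `patient`.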

## Stacked Denoising Autoencoder

As promised, let’s look at the nitty-gritty of RA-DAE in a bit more detail. First of all, at an abstract level, RA-DAE is essentially a Stacked Denoising Autoencoder (SDAE). To understand an SDAE, let us assume a dataset $D = \{(x, y)\}$, where each $(x, y)$ is a tuple containing a data point $x \in \mathbb{R}^{d}$ (e.g. an image unrolled into a single dimension of $d$ elements) and its corresponding one-hot encoded label $y$ over a total of $K$ classes.

In order to train this contraption, we optimize two losses: a generative loss $\mathcal{L}_{gen}$ and a discriminative loss $\mathcal{L}_{disc}$. In plain English, the generative loss (or reconstruction error) makes the network learn a good compression of the data. This is done by forward propagating a data point up to the last hidden layer (which has a smaller dimensionality than the input space) and trying to reconstruct the original input from the representation in that last hidden layer. The discriminative loss, on the other hand, is responsible for the actual vision task (e.g. image classification) you are using the SDAE for.

Mathematically, there’s nothing fancy here: both are just cross-entropy losses, one between the input and its reconstruction, and one between the true label and the predicted class probabilities.
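For reference, here is one standard way to write the two cross-entropy losses, assuming inputs scaled to $[0, 1]$, a reconstruction $\hat{x}$ of the $d$-dimensional input $x$, and softmax outputs $\hat{y}$ over $K$ classes for the one-hot label $y$. The notation is ours and may differ from the paper’s.

$$
\mathcal{L}_{gen}(x, \hat{x}) = -\sum_{j=1}^{d} \left[ x_j \log \hat{x}_j + (1 - x_j) \log (1 - \hat{x}_j) \right]
$$

$$
\mathcal{L}_{disc}(y, \hat{y}) = -\sum_{k=1}^{K} y_k \log \hat{y}_k
$$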

Then we optimize the SDAE weights and biases in all the layers w.r.t. these two losses using an optimizer (e.g. SGD, Momentum, NAG, Adam). Oh! And the **denoising** part is just a regularizer: we apply some noise to the original image and try to reconstruct the noise-free original image from the hidden representation obtained from the noisy image.
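Here is a minimal pure-Python sketch of the forward pass and both losses for a *single* hidden layer (a stacked version repeats the encoder). The masking-noise corruption, the noise level `0.2`, the untied encoder/decoder weights and all helper names are assumptions for illustration. There is no training loop: in practice an autodiff framework would compute the gradients of these losses for the optimizer.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def identity(z):
    return z


def dense(x, W, b, act):
    """Fully connected layer: W holds one weight row per output unit."""
    return [act(sum(w * xi for w, xi in zip(row, x)) + bias) for row, bias in zip(W, b)]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def mask_noise(x, p, rng):
    """Denoising corruption: zero each input element with probability p."""
    return [0.0 if rng.random() < p else v for v in x]


def generative_loss(x, x_hat, eps=1e-12):
    """Binary cross-entropy between the clean input and its reconstruction."""
    return -sum(xi * math.log(xh + eps) + (1 - xi) * math.log(1 - xh + eps)
                for xi, xh in zip(x, x_hat))


def discriminative_loss(y, y_hat, eps=1e-12):
    """Categorical cross-entropy between the one-hot label and the softmax output."""
    return -sum(yk * math.log(yh + eps) for yk, yh in zip(y, y_hat))


def dae_losses(x, y, params, noise=0.2, rng=None):
    """One-hidden-layer denoising autoencoder with a classifier head."""
    if rng is None:
        rng = random.Random(0)
    x_tilde = mask_noise(x, noise, rng)                            # noisy input
    h = dense(x_tilde, params["W_enc"], params["b_enc"], sigmoid)  # hidden code of the noisy input
    x_hat = dense(h, params["W_dec"], params["b_dec"], sigmoid)    # reconstruction, compared to clean x
    y_hat = softmax(dense(h, params["W_out"], params["b_out"], identity))
    return generative_loss(x, x_hat), discriminative_loss(y, y_hat)


def init_params(d, h, k, rng):
    def rand_matrix(rows, cols):
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]
    return {
        "W_enc": rand_matrix(h, d), "b_enc": [0.0] * h,
        "W_dec": rand_matrix(d, h), "b_dec": [0.0] * d,
        "W_out": rand_matrix(k, h), "b_out": [0.0] * k,
    }


params = init_params(d=4, h=2, k=2, rng=random.Random(42))
x = [1.0, 0.0, 1.0, 0.0]  # toy "image" with values in [0, 1]
y = [0.0, 1.0]            # one-hot label for class 1
l_gen, l_disc = dae_losses(x, y, params)
print(l_gen, l_disc)
```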

## RA-DAE

### Problem Formulation as an MDP

Now let’s see how we convert an SDAE into this fancy-looking RA-DAE. The most important component of RA-DAE is the elegant adaptation mechanism it uses. In order to apply Reinforcement Learning (RL) to the problem of adapting a neural network, we formulate the problem as a Markov Decision Process (MDP) (Chapter 3 of Reinforcement Learning: An Introduction). In summary, an MDP consists of four important things:

- A state space $S$ (a set of states)
- An action space $A$ (a set of actions)
- A transition function $T(s, a, s')$ (the probability of going from state $s$ to state $s'$ by taking action $a$)
- A reward function $R(s, a)$ (the reward for taking action $a$ in state $s$)

The solution to an MDP is an optimal policy $\pi^*(s)$ that returns the best action for state $s$.
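To see how the four components and an optimal policy fit together, here is a toy two-state MDP solved with textbook value iteration. All states, actions, probabilities and rewards are made up. Note that value iteration needs the transition function $T$, which is exactly what RA-DAE cannot define, so it relies on Q-learning instead.

```python
STATES = ["s0", "s1"]
ACTIONS = ["stay", "move"]

# T[s][a] -> list of (next_state, probability)
T = {
    "s0": {"stay": [("s0", 1.0)], "move": [("s1", 0.9), ("s0", 0.1)]},
    "s1": {"stay": [("s1", 1.0)], "move": [("s0", 1.0)]},
}
# R[s][a] -> immediate reward for taking action a in state s
R = {
    "s0": {"stay": 0.0, "move": -0.1},
    "s1": {"stay": 1.0, "move": 0.0},
}


def value_iteration(gamma=0.9, tol=1e-8):
    V = {s: 0.0 for s in STATES}

    def lookahead(s, a):
        # expected value of taking a in s: reward plus discounted next-state value
        return R[s][a] + gamma * sum(p * V[s2] for s2, p in T[s][a])

    while True:
        delta = 0.0
        for s in STATES:
            best = max(lookahead(s, a) for a in ACTIONS)
            delta = max(delta, abs(best - V[s]))
            V[s] = best
        if delta < tol:
            break
    # the optimal policy picks the action with the best one-step lookahead
    policy = {s: max(ACTIONS, key=lambda a: lookahead(s, a)) for s in STATES}
    return V, policy


V, policy = value_iteration()
print(policy)  # {'s0': 'move', 's1': 'stay'}
```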

### How do we fit an MDP to RA-DAE?

How do we define the problem at hand as an MDP? We define the state space, the action space and the reward function as described below. Note that we cannot define a transition function, because our environment is too complex to model or derive one.

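The state at the $n$-th iteration includes $\bar{\mathcal{L}}_{gen}^{(n)}$ and $\bar{\mathcal{L}}_{disc}^{(n)}$, the rolling (i.e. exponential moving) averages of the generative and discriminative losses over the most recent data batches.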
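A rolling average of a per-batch loss can be kept with a few lines of Python. The class name `RollingLoss`, the decay `beta=0.9` and the batch losses below are hypothetical, and the two-element state tuple is a simplification of whatever the paper’s full state contains.

```python
class RollingLoss:
    """Exponential moving average of a loss, updated once per data batch."""

    def __init__(self, beta=0.9):
        self.beta = beta  # hypothetical decay: larger beta means a longer memory
        self.value = None

    def update(self, batch_loss):
        if self.value is None:
            self.value = batch_loss  # initialise with the first observation
        else:
            self.value = self.beta * self.value + (1 - self.beta) * batch_loss
        return self.value


gen_avg, disc_avg = RollingLoss(), RollingLoss()
for l_gen, l_disc in [(0.8, 0.7), (0.6, 0.5), (0.5, 0.4)]:
    state = (gen_avg.update(l_gen), disc_avg.update(l_disc))
print(state)  # roughly (0.752, 0.652)
```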

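For the actions, the add action creates new nodes (i.e. neurons) and greedily initializes them using a pool $B_r$ containing the last few data batches seen so far, while the merge action merges the pairs of nodes with the smallest cosine distance between them, effectively removing nodes. Furthermore, how many nodes an action adds or removes is controlled by an additional equation in the paper.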

The equation may look intimidating, but the idea is simple: it is designed to prevent the network from growing too large or too small.
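The merge action can be sketched as follows, treating each hidden node as its incoming weight vector. Averaging the two closest nodes is an assumed merge rule for illustration (the post only says the closest pairs are merged), and a real network would also need to combine the merged nodes’ outgoing weights.

```python
import math


def cosine_distance(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (nu * nv)


def merge_closest_pair(weights):
    """weights[i] is the incoming weight vector of hidden node i.
    Replaces the two nodes with the smallest cosine distance by their
    average (an assumed merge rule), so the layer loses one node."""
    best = None
    for i in range(len(weights)):
        for j in range(i + 1, len(weights)):
            d = cosine_distance(weights[i], weights[j])
            if best is None or d < best[0]:
                best = (d, i, j)
    _, i, j = best
    merged = [(a + b) / 2 for a, b in zip(weights[i], weights[j])]
    remaining = [w for k, w in enumerate(weights) if k not in (i, j)]
    return remaining + [merged]


nodes = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
print(merge_closest_pair(nodes))  # nodes 0 and 1 become roughly [0.95, 0.05]
```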

The reward function not only rewards reducing the error, but also rewards increasing the rate at which the error is reducing.
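As an illustration only (this is *not* the paper’s reward function, which is not reproduced here), a reward in the same spirit could combine the latest drop in error with how much faster the error is falling than in the previous step. The function name, the weight `lam` and the error values are hypothetical.

```python
def reward(errors, lam=1.0):
    """Illustrative reward, not the paper's formula.
    errors: the last three (rolling-average) errors [e_{n-2}, e_{n-1}, e_n]."""
    e2, e1, e0 = errors
    drop_now = e1 - e0     # how much the error fell in this step
    drop_before = e2 - e1  # how much it fell in the previous step
    # first term rewards reducing the error, second rewards a faster reduction
    return drop_now + lam * (drop_now - drop_before)


print(reward([0.5, 0.4, 0.25]))  # about 0.2: error falling faster than before
print(reward([0.5, 0.4, 0.3]))   # about 0.1: error falling at a steady rate
```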

### Learning a Policy with Q-Learning

Now that we have defined the state space, the action space and the reward function, it is possible to learn a policy $\pi(s)$ that gives the best action for a given state $s$. To learn this policy we use a temporal-difference technique called Q-Learning (Chapter 6 of Reinforcement Learning: An Introduction). For the exact way the algorithm is used, please refer to Section 3.4 and Algorithm 3 of the paper.
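The core of Q-learning is a single update rule, $Q(s,a) \leftarrow Q(s,a) + \alpha\,[r + \gamma \max_{a'} Q(s',a') - Q(s,a)]$. Below is a tabular sketch of it; the discretised state names, `alpha=0.5` and `gamma=0.9` are hypothetical, since RA-DAE’s real states are continuous.

```python
def q_learning_update(Q, s, a, r, s_next, alpha=0.5, gamma=0.9):
    """One tabular Q-learning step:
    Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))."""
    td_target = r + gamma * max(Q[s_next].values())
    Q[s][a] += alpha * (td_target - Q[s][a])
    return Q[s][a]


# hypothetical table with two discretised states and the two adaptation actions
Q = {
    "low_error": {"add": 0.0, "merge": 0.0},
    "high_error": {"add": 1.0, "merge": 2.0},
}
print(q_learning_update(Q, "low_error", "add", r=1.0, s_next="high_error"))  # about 1.4
```

The target bootstraps from the best Q-value of the *next* state, which is what lets the value reflect long-term rather than immediate reward.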

### How do we deal with the continuous state space?

Now you might notice that our state space is continuous, but tabular Q-learning is straightforward only for discrete action and state spaces. So how do we use Q-learning with a continuous state space? The answer is **Function Approximation**. The idea is that for an unseen state $s$, the Q-value is obtained as $Q(s, a) \approx \hat{Q}(s, a; \theta)$, where $\theta$ are the parameters of the approximator. In the paper we use a Gaussian Process Regressor (GPR). To do this, we keep a list of tuples (we call this Experience) of the form $(s, q_a)$, where $s$ is a state we’ve seen and $q_a$ is the Q-value we got for taking action $a$ in state $s$. Then, using the collected experience, we fit a separate GP for each action $a$. For a new, unseen state $s$, we regress $Q(s, a)$ for all $a \in A$ and select the action with the maximum Q-value.
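Here is a self-contained sketch of the per-action regression and greedy action selection. To avoid dependencies it implements only the posterior mean of a zero-mean GP with an RBF kernel in pure Python; the length scale, noise level, class and function names and the experience data are assumptions, and hyperparameter fitting is omitted. In practice you would more likely use a library such as scikit-learn’s `GaussianProcessRegressor`.

```python
import math


def rbf(u, v, length_scale=1.0):
    sq = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq / (2 * length_scale ** 2))


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


class TinyGP:
    """Posterior mean of a zero-mean GP with an RBF kernel."""

    def __init__(self, length_scale=1.0, noise=1e-3):
        self.length_scale, self.noise = length_scale, noise

    def fit(self, experience):
        """experience: list of (state, q_value) tuples for ONE action."""
        self.X = [s for s, _ in experience]
        y = [q for _, q in experience]
        K = [[rbf(u, v, self.length_scale) + (self.noise if i == j else 0.0)
              for j, v in enumerate(self.X)] for i, u in enumerate(self.X)]
        self.alpha = solve(K, y)  # weights of the posterior mean
        return self

    def predict(self, s):
        return sum(w * rbf(s, x, self.length_scale) for w, x in zip(self.alpha, self.X))


def greedy_action(gps, s):
    """Regress Q(s, a) with each action's GP and return the argmax action."""
    return max(gps, key=lambda a: gps[a].predict(s))


experience = {
    "add":   [((0.9, 0.8), 1.0), ((0.2, 0.1), -0.5)],  # adding helps when losses are high
    "merge": [((0.9, 0.8), -0.2), ((0.2, 0.1), 0.6)],  # merging helps when losses are low
}
gps = {a: TinyGP().fit(exp) for a, exp in experience.items()}
print(greedy_action(gps, (0.85, 0.75)))  # add
print(greedy_action(gps, (0.25, 0.15)))  # merge
```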

## Putting Everything Together

I know this is not the most comprehensive explanation of RA-DAE, but you can always refer to the original paper for the exact details (Section 3.4). Still, I’ll do my best to explain how things are put together to form a coherent whole. I will write this down in Python syntax to give a fresh perspective instead of restating things from the paper.

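```python
import random

# s_n                      - current state
# s_n_minus_1, a_n_minus_1 - previous state and action
# A                        - list of actions
# experience[a]            - list of (state, q) tuples collected for action a
# gpr_collection[a]        - GP regressor for action a
# w                        - selected hyperparameters for all GPRs
# alpha, gamma             - Q-learning step size and discount factor
# epsilon                  - epsilon for epsilon-greedy exploration
# fit_interval             - interval (in iterations) at which all GPRs are refitted
# E, eta_1, eta_2          - total iterations and the ends of the two warm-up phases
# calculate_state, calculate_reward, trainSDAE, GPR and B_r stand in for parts of
# RA-DAE described in the paper; apply_action is a hypothetical helper that
# performs the add/merge action on the network

gpr_collection = {a: GPR(w) for a in A}
experience = {a: [] for a in A}
s_n_minus_1, a_n_minus_1 = None, None

for n in range(E):
    s_n = calculate_state()
    a_n = None
    if n < eta_1:
        trainSDAE(B_r)  # train the SDAE with the data pool B_r, no adaptation yet
    elif n < eta_2:
        if a_n_minus_1 is not None:
            # the reward of the previous action serves as its initial Q estimate
            r_n = calculate_reward(s_n_minus_1, a_n_minus_1)
            experience[a_n_minus_1].append((s_n_minus_1, r_n))
        a_n = A[n % len(A)]  # cycle through the actions to gather experience
    else:
        # refit all GPRs every fit_interval iterations (including the first one of
        # this phase). Refitting every iteration can lead to an explosion of Q values
        if (n - eta_2) % fit_interval == 0:
            for a in A:
                gpr_collection[a].fit(experience[a])
        r_n = calculate_reward(s_n_minus_1, a_n_minus_1)
        q_vector_new = [gpr_collection[a].predict(s_n) for a in A]
        # Q update for the previous (state, action) pair: the target uses the
        # best Q value of the *current* state s_n
        q_prev = gpr_collection[a_n_minus_1].predict(s_n_minus_1)
        q_last = (1 - alpha) * q_prev + alpha * (r_n + gamma * max(q_vector_new))
        experience[a_n_minus_1].append((s_n_minus_1, q_last))
        if random.random() < epsilon:
            a_n = random.choice(A)  # explore: pick a random action
        else:
            a_n = A[max(range(len(A)), key=lambda i: q_vector_new[i])]  # exploit
    if a_n is not None:
        apply_action(a_n)
    s_n_minus_1, a_n_minus_1 = s_n, a_n
```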

## Results obtained with RA-DAE compared to SDAE and MI-DAE

These are some results comparing the accuracies obtained on the MNIST, CIFAR-10 and MNIST-ROT-BACK datasets. We also wrote a script to generate datasets whose class distributions vary over time (sub-image (f) of the results) using Gaussian priors over the classes. We compared three algorithms: RA-DAE (our approach), MI-DAE and SDAE.
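The dataset-generation script itself is not shown in the post, but the idea of Gaussian priors over classes can be sketched like this: each class gets a Gaussian bump over time, and the labels in each batch are sampled from the normalised bumps. The function names, the bump centres and the width `sigma` are hypothetical; the sampled labels would then be used to pick images of each class.

```python
import math
import random


def class_probs(t, centers, sigma):
    """Class distribution at time t: class c gets a Gaussian bump centred at centers[c]."""
    weights = [math.exp(-((t - mu) ** 2) / (2 * sigma ** 2)) for mu in centers]
    total = sum(weights)
    return [w / total for w in weights]


def shifting_label_stream(n_steps, batch_size, centers, sigma, seed=0):
    """Yields one batch of class labels per time step, drawn from class_probs(t)."""
    rng = random.Random(seed)
    classes = list(range(len(centers)))
    for t in range(n_steps):
        probs = class_probs(t, centers, sigma)
        yield rng.choices(classes, weights=probs, k=batch_size)


# two classes (e.g. digits 1 and 2): class 0 dominates early, class 1 dominates late
stream = list(shifting_label_stream(n_steps=100, batch_size=50, centers=[20, 80], sigma=15))
print(stream[0].count(0), stream[-1].count(1))  # both close to 50
```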

That’s pretty much it. I hope this was useful to you; the code is available here. Just a heads up: the code is not as readable as you might expect, but I intend to clean it up to make it more reader-friendly.