Generative Models Part I: Variational Autoencoders

In the next few posts, I will be talking about several generative models that saw daylight quite recently. In general, generative models generate new samples by approximating the underlying distribution of the data, and they are quite useful for unsupervised and semi-supervised learning. A few state-of-the-art generative models are Variational Autoencoders (VAEs) by Kingma et al. and Generative Adversarial Networks (GANs) by Goodfellow et al. GANs shine when it comes to creating photo-realistic images.

Variational Autoencoders (VAE)

A Variational Autoencoder is a model composed of two components: an encoder f(X;\phi) and a decoder f(z;\theta). The encoder maps X into a latent variable z in the latent space Z. The encoder f(X;\phi) and decoder f(z;\theta) are both powerful function approximators (e.g. neural networks). This pushes the encoder to learn a probability distribution of z, which we denote by Q(z|X,f(X,\phi)). The decoder takes a sample from the latent space, z \sim Q(z|X,f(X,\phi)), and generates a sample x. The sample x is drawn as x \sim P(X|z,f(z,\theta)), where P(X|z,f(z,\theta)) is the approximated probability distribution for X learnt by the decoder. This is what happens under the hood in the VAE.

Now, we dive head first into the technical details. I will try to be as thorough as possible while keeping an easy-to-follow flow of how things fall into place. Beyond this point, you must clear all thoughts and enter a state of peace. Seriously though, forget the above explanation, because we are about to embark on the journey that will explain what exactly VAEs are.

Key Idea behind VAEs?

Before coming up with the mechanism that dictates VAEs, let’s first understand what the ultimate goal of VAEs is. The goal is to find a distribution Q(z|X) of some latent variable that we can sample from (z \sim Q(z|X)) in order to generate a new sample \tilde{x} \sim P(X|z,f(z,\theta)) (to be precise, a sample from an approximated P(X)) through the decoder (a function approximator f(z,\theta) such as a neural network). I know, I know … This statement raises more questions than it answers. One key point you should take away from here onwards is that the decoder maps (or decodes) the latent variable z to a new sample that comes from P(X).


Let’s make this a little less Einsteiny…

This is harder than I thought. First of all, I’m going to explain this both probabilistically and non-probabilistically. But after that, I’m going to stick with the probabilistic interpretation, because otherwise it is hard to explain the theory. So you had better pull out the probability drawers in your brain if they’ve been rusting away.

Probabilistic Interpretation

So let’s set foot on this. We want to find the distribution of some latent (i.e. hidden) representation, Q(z|X), of the data, whose underlying distribution P(X) we don’t know. z should capture some semantic features of X. This allows us to generate a new sample \tilde{x} in the following way using the decoder f(z,\theta).

  1. Sample \tilde{z} from Q(z|X) (i.e. \tilde{z}\sim Q(z|X))
  2. Sample \tilde{x} from P(X|\tilde{z},f(\tilde{z},\theta)) (i.e. \tilde{x} \sim P(X|\tilde{z},f(\tilde{z},\theta)))

Non-Probabilistic Interpretation

We want to find some latent (i.e. hidden) representation z of the data X. z should capture some semantic features of X, so that by looking at the values of z we have an idea of what the data that generated this z should look like. This allows us to generate a new sample \tilde{x} by tweaking the values of z (which results in a new \tilde{z}) and using the decoder f(z,\theta), where \tilde{x} = f(\tilde{z},\theta).

Digit Example

For example, if we want to generate images of digits, then by observing images (i.e. X) we define a z made up of variables like z=(stroke_thickness, angle, scale, …), so that we can slightly change the values of z and generate new samples.

This is quite a nice approach with a solid theoretical foundation. But there are a few important questions left to be answered, which we will do below.

Problem: How do we find Q(z|X)?

Easier said than done! Unfortunately, it’s not as easy as it looks to design Q(z|X). If you attempt to hand-engineer z you will run into a multitude of problems (What are the variables to consider? What is the importance of each variable?). And to make matters worse, these variables might be correlated. In the digit example, a smaller stroke_thickness might go hand in hand with a higher angle. And even if you figure all this out, it would take years to create the dataset for z. So, by the looks of it, it is not a good idea to design this by hand.

Solution: Let prior P(z) = \mathcal{N}(0,I)

Yes, you saw it correctly! We assume that z \sim \mathcal{N}(0,I). That is, we assume z comes from a zero-mean, unit-variance normal distribution. Sounds crazy, but it works, because it is not as crazy as it looks. Here’s how this works.


\mathcal{N}(0,I) is a powerful distribution. With a “powerful” enough function g(z) we can turn z into an arbitrarily complex distribution. This tutorial (Figure 2) highlights a nice example of that. This means that, if we have the correct g(z), we can map z to the meaningful semantic features we want.
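To make this concrete, here is a minimal sketch of one such mapping. The particular g below (z/10 + z/||z||) is a toy choice assumed purely for illustration; it bends 2D standard-normal samples into a thin ring, a shape that looks nothing like a Gaussian blob.

```python
import numpy as np

def g(z):
    """Toy mapping (assumed for illustration): g(z) = z/10 + z/||z||."""
    norm = np.linalg.norm(z, axis=1, keepdims=True)
    return z / 10.0 + z / norm

rng = np.random.default_rng(0)
z = rng.standard_normal((5000, 2))   # z ~ N(0, I) in 2D
x = g(z)                             # samples now lie on a noisy ring

# Every mapped point sits at radius 1 + ||z||/10, i.e. just outside the unit circle.
radii = np.linalg.norm(x, axis=1)
print(radii.min(), radii.mean(), radii.std())
```

Plotting x as a scatter plot should show the ring directly. In a VAE nobody writes g by hand like this; the point is only that a deterministic function of z can reshape the distribution drastically.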


Now we have another question on our hands: what is g(z)?

What is g(z)? g(z) is a part of f(z,\theta)

Remember, f(z,\theta), our decoder, is a function approximator. So why not approximate g(z) with that? To add more clarity, suppose the decoder is a four-layer neural network and take the digit example: the first two layers would map z into the meaningful latent variable space we defined earlier, and the next two layers would convert those latent variables into a fully rendered image of the digit. Of course that’s not exactly the way it works. But it is a nice intuitive way to see how things finally become coherent.


Sampling from z: Curse of Dimensionality

From the above discussion, you know that we need to sample z \sim Q(z|X) to generate new samples. But we have a practical problem in front of us. If you just sample z \sim \mathcal{N}(0,I), then for most z, P(X|z) will be nearly zero. This is a consequence of the curse of dimensionality. So we have another problem on our hands: how do we find an effective way to sample z so that z comes from the region where P(X|z) is non-negligible?
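A rough way to see the problem numerically is sketched below. It is a toy setup, not a real VAE: we assume the "useful" z values for some data point are those within a fixed radius of a target point, and we even place that target at the origin, the most likely point under the prior. The radius of 1.0 and the list of dimensions are arbitrary choices for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100_000

def fraction_near_target(dim, radius=1.0):
    """Fraction of z ~ N(0, I) that falls within `radius` of the origin.

    The origin is the mode of the prior, so this is the best case
    for blind sampling from the prior.
    """
    z = rng.standard_normal((n_samples, dim))
    return np.mean(np.linalg.norm(z, axis=1) < radius)

fractions = {dim: fraction_near_target(dim) for dim in (1, 2, 10, 50)}
for dim, frac in fractions.items():
    print(f"dim={dim:3d}  fraction of useful samples={frac:.5f}")
```

The fraction of useful samples collapses as the dimension of z grows, which is why blindly sampling from the prior and hoping to hit the right region does not scale.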

Solution: Sample from a posterior Q(z|X) obtained with a function approximator

It seems that there’s a smarter way to get around sampling, instead of waiting a millennium to find a z that gives a non-zero P(X|z). Why not use a function approximator f(X,\phi) to find the distribution Q(z|X) that results in a non-zero P(X|z)? What is this f(X,\phi)? Ah, things are finally falling into place. f(X,\phi) is the encoder we talked about earlier. More formally, f(X,\phi) will output \mu(X,\phi) and \sigma^2(X,\phi), which are the mean and variance of the Gaussian (with diagonal covariance) that we will be sampling z from.

Finally, the complete picture

So we came all the way from defining our goal, to why we need the decoder, to why we need the encoder. Here’s an image of how things look together.


So in order, we do the following.

  1. Sample x i.e. x \sim P(X)
  2. Sample z i.e. z \sim Q(z|X,f(X,\phi))=\mathcal{N}(\mu(X,\phi),\sigma^2(X,\phi))
  3. Sample \tilde{x} i.e. \tilde{x} \sim P(X|z,f(z,\theta))
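The three steps can be sketched in a few lines of NumPy. Everything here is a stand-in: the encoder and decoder are single linear layers with random weights (the dictionaries phi and theta), random noise plays the role of the dataset, and P(X|z) is assumed to be Gaussian with unit variance.

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 8, 2

# Hypothetical single-layer encoder/decoder weights (phi and theta).
phi = {"W_mu": rng.standard_normal((x_dim, z_dim)) * 0.1,
       "W_logvar": rng.standard_normal((x_dim, z_dim)) * 0.1}
theta = {"W": rng.standard_normal((z_dim, x_dim)) * 0.1}

def encoder(x, phi):
    """f(X, phi): returns the mean and variance of Q(z|X)."""
    mu = x @ phi["W_mu"]
    var = np.exp(x @ phi["W_logvar"])   # exp keeps the variance positive
    return mu, var

def decoder(z, theta):
    """f(z, theta): returns the mean of P(X|z)."""
    return z @ theta["W"]

# 1. Sample x ~ P(X); random data stands in for a real dataset here.
x = rng.standard_normal((4, x_dim))
# 2. Sample z ~ Q(z|X) = N(mu(X, phi), sigma^2(X, phi)).
mu, var = encoder(x, phi)
z = rng.normal(loc=mu, scale=np.sqrt(var))
# 3. Sample x_tilde ~ P(X|z), assumed Gaussian with unit variance.
x_tilde = rng.normal(loc=decoder(z, theta), scale=1.0)

print(z.shape, x_tilde.shape)
```

Having the encoder predict the log-variance and exponentiating it is a common convenience rather than a requirement; it simply guarantees a positive variance.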

Hold on! P(X|z) and Q(z|X) need to be learnt

By this time you have probably realized that you need some basis on which to train the encoder and the decoder. I’m going to assume both are neural networks. For that we are going to use the maximum likelihood approach. The intuition is that we want to find values of \theta and \phi such that we are likely to get data like X (i.e. maximize P(X)). Now we will set up our objective function along these lines. In other words, we need to find some PDF of z, Q(z), that puts its mass on values of z that are likely given the data X. Mathematically, we measure this by the divergence between Q(z) and P(z|X), specifically the Kullback-Leibler divergence.

D_{KL}[Q(z)||P(z|X)] = E_{z\sim Q}[log(Q(z))] - E_{z\sim Q}[log(P(z|X))]

After some number crunching, you will arrive at the following equation (see Section 2.1 for the exact maths),

\text{log}(P(X)) - D_{KL}[Q(z|X)||P(z|X)] = E_{z \sim Q}[\text{log}(P(X|z))] - D_{KL}[Q(z|X)||P(z)]
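For those who want to see the number crunching, the steps can be sketched as follows. The only ingredients are the definition of the KL divergence and Bayes’ rule, \text{log}(P(z|X)) = \text{log}(P(X|z)) + \text{log}(P(z)) - \text{log}(P(X)); note that \text{log}(P(X)) does not depend on z, so it can be pulled out of the expectation.

```latex
\begin{align}
D_{KL}[Q(z)||P(z|X)] &= E_{z\sim Q}[\log Q(z) - \log P(z|X)] \\
&= E_{z\sim Q}[\log Q(z) - \log P(X|z) - \log P(z)] + \log P(X) \\
\log P(X) - D_{KL}[Q(z)||P(z|X)] &= E_{z\sim Q}[\log P(X|z)] - E_{z\sim Q}[\log Q(z) - \log P(z)] \\
&= E_{z\sim Q}[\log P(X|z)] - D_{KL}[Q(z)||P(z)]
\end{align}
```

This holds for any Q(z); choosing a Q that depends on X, i.e. Q(z|X), gives the equation in the form used here.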

LHS of the Equation

So by maximizing the RHS, we will be maximizing the log likelihood \text{log}(P(X)) while simultaneously minimizing the term D_{KL}[Q(z|X)||P(z|X)]. This term should not be ignored: driving it towards 0 gives the additional benefit that Q(z|X) approximates the intractable P(z|X). Moreover, it has been proven (for the 1D case) that this term in fact reaches zero given that \sigma^2(X,\phi) is small.

RHS of the Equation

We need some tangible values for the RHS if we want to maximize the likelihood of P(X). Well, E_{z \sim Q}[\text{log}(P(X|z))] is something you can write blindfolded if you are into deep networks. If P(X|z) is taken to be a Gaussian centred on f(z,\theta), it is, up to a constant and a scale factor, the negative squared reconstruction error between X and f(z,\theta).

Next, the second term can be computed in closed form. It’s the KL divergence between two Gaussians, given by

\mathcal{D}_{KL}[\mathcal{N}(\mu(X,\phi),\sigma^2(X,\phi))||\mathcal{N}(0,I)] = \frac{1}{2}(\text{tr}(\sigma^2(X,\phi)) + \mu(X,\phi)^T\mu(X,\phi) - k - \text{log det}(\sigma^2(X,\phi)))

where k is the dimensionality of z and \sigma^2(X,\phi) is treated as a diagonal covariance matrix.
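As a sanity check, here is a small sketch of this closed form for a diagonal covariance, where the trace becomes a sum of variances and the log-determinant a sum of log-variances. The values of mu and var are made up, and the Monte Carlo estimate is only there to confirm the formula numerically.

```python
import numpy as np

def kl_to_standard_normal(mu, var):
    """KL[N(mu, diag(var)) || N(0, I)] for one latent vector."""
    k = mu.shape[0]
    return 0.5 * (np.sum(var) + mu @ mu - k - np.sum(np.log(var)))

mu = np.array([0.5, -1.0])    # made-up encoder outputs
var = np.array([0.8, 1.5])
closed_form = kl_to_standard_normal(mu, var)

# Monte Carlo estimate of E_{z~Q}[log Q(z) - log P(z)] as a sanity check.
rng = np.random.default_rng(0)
z = mu + np.sqrt(var) * rng.standard_normal((200_000, 2))
log_q = -0.5 * np.sum((z - mu) ** 2 / var + np.log(2 * np.pi * var), axis=1)
log_p = -0.5 * np.sum(z ** 2 + np.log(2 * np.pi), axis=1)
monte_carlo = np.mean(log_q - log_p)

print(closed_form, monte_carlo)
```

The two numbers should agree to a couple of decimal places, and the closed form is exactly zero when mu is 0 and var is 1, i.e. when Q(z|X) already equals the prior.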

Now by putting these two terms together, the RHS becomes easy to implement. I won’t state the obvious here (The full equation).
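For readers who would still like to see the two terms side by side, here is a minimal sketch of the resulting loss (the negative of the RHS, so that it can be minimized). It assumes P(X|z) is Gaussian with unit variance and a diagonal Q(z|X); the function name negative_elbo and the toy inputs are mine, not part of any library.

```python
import numpy as np

def negative_elbo(x, x_recon, mu, var):
    """Per-example loss = reconstruction error + KL[Q(z|X) || N(0, I)].

    Assumes P(X|z) is Gaussian with unit variance, so -log P(X|z) equals
    half the squared error plus a constant that is dropped here.
    """
    recon = 0.5 * np.sum((x - x_recon) ** 2, axis=1)
    kl = 0.5 * np.sum(var + mu ** 2 - 1.0 - np.log(var), axis=1)
    return recon + kl

x = np.array([[1.0, 2.0, 3.0]])
x_recon = np.array([[1.0, 2.0, 2.0]])   # decoder output f(z, theta)
mu = np.array([[0.0, 0.0]])             # encoder outputs matching the prior
var = np.array([[1.0, 1.0]])
print(negative_elbo(x, x_recon, mu, var))  # → [0.5]
```

In this toy call the KL term vanishes because Q(z|X) equals the prior, so the loss is just half the squared reconstruction error.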


Now it’s just a matter of optimizing this objective function w.r.t. \phi \text{ and } \theta using a stochastic gradient method (e.g. momentum update, Adam optimizer). Right now, this is what our model looks like. The solid lines represent connections that result in a deterministic output in the model, whereas dashed lines show connections that result in stochastic outputs.


There’s a major issue that does not allow us to backpropagate through the whole model. Do you see it? Since we are sampling z from a distribution, the end-to-end deterministic path through the model is lost. In other words, backpropagation will work with stochastic inputs, but not with stochastic units inside the model. If we speak in terms of the above image, it is fine to have those dashed lines at the ends of our model, but we need solid lines from end to end in our model, otherwise we cannot backpropagate end-to-end.

Solution: Reparameterization Trick

Don’t lose hope. There’s an elegant trick to level out this lump of misfortune, called the “reparameterization trick”. Instead of sampling z as below,

z \sim \mathcal{N}(\mu(X,\phi),\sigma^2(X,\phi))

we do the following

z = \mu(X,\phi) +  \sigma(X,\phi)\epsilon where \epsilon \sim \mathcal{N}(0,I).

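This moves the randomness into \epsilon, which acts as just another input, so z becomes a deterministic function of \mu and \sigma. The broken link between the encoder and decoder disappears, enabling backpropagation to work as well as ever.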
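Here is a small sketch of why this helps. With \epsilon held fixed, z is an ordinary differentiable function of \mu and \sigma, so its derivatives are dz/d\mu = 1 and dz/d\sigma = \epsilon. The finite-difference check below, with made-up values for mu and sigma, confirms that numerically.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_z(mu, sigma, eps):
    """Reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I)."""
    return mu + sigma * eps

mu = np.array([0.3, -0.2])          # made-up encoder outputs
sigma = np.array([0.5, 2.0])
eps = rng.standard_normal(2)        # all randomness lives in eps

# With eps held fixed, z is a deterministic function of mu and sigma,
# so finite differences recover dz/dmu = 1 and dz/dsigma = eps.
h = 1e-6
dz_dmu = (sample_z(mu + h, sigma, eps) - sample_z(mu, sigma, eps)) / h
dz_dsigma = (sample_z(mu, sigma + h, eps) - sample_z(mu, sigma, eps)) / h

print(dz_dmu, dz_dsigma, eps)
```

An automatic differentiation framework would compute these same gradients during backpropagation, which is exactly what sampling z directly from the distribution would not allow.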


Wrapping Up

Yep, this is really it. That’s all there is to the basics of VAEs. I know it’s a mouthful of mathematics and probability and whatnot. But hopefully, after going through it a few times and implementing it on your own, things will make sense.

Tips and Tricks

When implementing VAEs, be mindful of your initialization: if you are using rectified linear unit (ReLU) activations, make sure you have a good initialization (e.g. Xavier initialization); otherwise the model will converge very quickly to a poor solution, probably because of dead units.
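For reference, a minimal sketch of the uniform variant of Xavier (Glorot) initialization is shown below. The layer sizes 784 and 256 are hypothetical, and most deep learning frameworks ship their own version of this, so treat it as an illustration of the formula rather than something you need to write yourself.

```python
import numpy as np

def xavier_uniform(fan_in, fan_out, rng):
    """Xavier (Glorot) uniform initialization for a weight matrix."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

rng = np.random.default_rng(0)
W = xavier_uniform(784, 256, rng)   # hypothetical layer sizes
print(W.shape, W.std())
```

The resulting weights have variance 2 / (fan_in + fan_out), which keeps the scale of activations roughly stable from layer to layer at the start of training.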


My code can be found here.
