C2W2 VAE - mean and stddev

For the VAE encoder, to generate the mean and stddev, the code has this:

Can someone please explain how “encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()” calculates the means and stddevs of the encoder output? My understanding is that encoding has the shape (batch size, output_channels = 32).

Thanks!

Hi @sooolee,

At a high level, the mean and stddev values are learned by training the model, that is, by repeatedly computing the loss and doing backprop.

As for the specific values returned by the encoder’s forward method:

If you look at the last line of the Sequential model for the encoder, notice that it actually outputs output_chan * 2 values for each item in the batch:

         self.make_disc_block(hidden_dim * 2, output_chan * 2, final_layer=True),

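This is because we want two values per output channel: a mean and a stddev. So encoding actually has shape (batch size, output_chan * 2), i.e. (batch size, 64) when output_chan = z_dim = 32, not (batch size, 32). The first z_dim values of encoding[] are the means, and the next z_dim values are the stddevs, or more precisely the (natural) logs of the stddevs.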
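To make the shapes concrete, here is a minimal sketch with made-up sizes. The `final_layer` below is a hypothetical stand-in for the encoder's last block (a plain `nn.Conv2d`, assuming a 4x4 feature map coming in), not the assignment's actual `make_disc_block`; the slicing is the same as in the forward method.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration; output_chan plays the role of z_dim
batch_size, hidden_dim, z_dim = 8, 16, 32

# Stand-in for the encoder's final layer: it emits z_dim * 2 channels.
# Assumes a 4x4 input feature map, so a 4x4 kernel with stride 2 gives a 1x1 output.
final_layer = nn.Conv2d(hidden_dim * 2, z_dim * 2, kernel_size=4, stride=2)

features = torch.randn(batch_size, hidden_dim * 2, 4, 4)
encoding = final_layer(features).view(batch_size, -1)  # flatten to (batch, channels)
print(encoding.shape)  # torch.Size([8, 64])

# Split along dim 1, the same way the forward method does
mean = encoding[:, :z_dim]     # first z_dim columns: means
log_std = encoding[:, z_dim:]  # last z_dim columns: log stddevs
std = log_std.exp()            # convert to positive stddevs

print(mean.shape, std.shape)  # torch.Size([8, 32]) torch.Size([8, 32])
```

The split is purely positional: nothing about the tensor itself marks which columns are means.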

The reason you see .exp() on the second value returned by forward is to convert the log value from encoding into the stddev we want to return. The convenient thing about having the network learn the log value is that any real number is valid, so negative outputs are perfectly legitimate. Taking .exp() in the forward method then converts it to a stddev, which guarantees it is always positive, as a stddev must be.
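A tiny sketch of the log/exp trick, using arbitrary example values for the raw network outputs: any real number (including negatives) maps to a strictly positive stddev, and taking the log recovers the original value.

```python
import torch
from torch.distributions import Normal

# Example raw network outputs: any real number, including negatives
log_std = torch.tensor([-3.0, -0.5, 0.0, 1.2])

std = log_std.exp()  # every entry is strictly positive
recovered = std.log()  # maps back to the original log values

# A valid normal distribution needs scale > 0, which exp() guarantees
dist = Normal(torch.zeros(4), std)
print(std, recovered)
```

Using the raw outputs directly as the scale would be invalid for the negative entries, which is exactly what the exponent avoids.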

Thanks a lot, Wendy. Treating the stddev as a log value makes sense.

And yep, the encoder output has output_chan * 2 channels. So regarding training and learning, I’m still fuzzy about how the network learns the first 32 channels to be means and the second 32 to be (log) stddevs… but conceptually, my understanding is: we build normal distributions from these values, sample z from them (using rsample), and backprop flows through this distribution and sampling process, which is how it learns? Am I close?
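The rsample part of this intuition can be checked directly. In this minimal sketch, `mean` and `log_std` are hypothetical leaf tensors standing in for the encoder's two halves, and the squared-sum loss is an arbitrary toy loss. `rsample()` uses the reparameterization z = mean + std * eps with eps ~ N(0, 1), so gradients reach both halves.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
z_dim = 4

# Stand-ins for the encoder's two halves (leaf tensors so .grad is populated)
mean = torch.zeros(2, z_dim, requires_grad=True)
log_std = torch.zeros(2, z_dim, requires_grad=True)

q_dist = Normal(mean, log_std.exp())
z = q_dist.rsample()  # reparameterized sample: differentiable w.r.t. mean and std

# Any downstream loss built from z (a toy one here)
loss = (z ** 2).sum()
loss.backward()

print(mean.grad is not None, log_std.grad is not None)  # True True
```

Swapping `rsample()` for `sample()` would break this: `sample()` draws without tracking gradients, so `loss.backward()` would raise an error.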

I have no idea how that’s learned mathematically, but I am assuming the network understands that we use the first 32 as means and the second 32 as stddevs in normal distributions…

Hi @sooolee,

The general idea for how it learns the stddev and mean values is much like it works for other models. The network doesn’t really know anything about the math of stddevs or means, just as an animal classification model doesn’t have any inherent knowledge about dogs or cats. It’s all about how we calculate the loss to measure how close the network is to the correct answer, and then using backprop to get closer and closer to the answer we want as we loop through training. The roles come from how we use the outputs: forward treats the first half as means and exponentiates the second half into stddevs, and the loss is computed on the resulting distribution, so backprop pushes each half to behave that way. For the VAE there are some adaptations to come up with functions that are viable for backprop, as explained in the assignment, but the general point is that these models learn in basically the same way as any other model.
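Here is a sketch of what such a loss can look like, under stated assumptions: a common VAE loss is a reconstruction term plus the KL divergence between the encoder's distribution and a standard normal. The BCE reconstruction term, the one-layer `decoder`, and the random `images` are all hypothetical stand-ins, and the assignment's exact loss may differ. The closed-form KL line is included as a cross-check of `kl_divergence`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch_size, z_dim = 8, 32

# Hypothetical encoder output (leaf tensor so we can inspect its gradient)
encoding = torch.randn(batch_size, z_dim * 2, requires_grad=True)
mean = encoding[:, :z_dim]
log_std = encoding[:, z_dim:]
std = log_std.exp()

# Hypothetical decoder: a single linear layer mapping z to 28x28 pixels
decoder = nn.Linear(z_dim, 28 * 28)
images = torch.rand(batch_size, 1, 28, 28)  # fake targets in [0, 1)

q_dist = Normal(mean, std)
z = q_dist.rsample()
recon = torch.sigmoid(decoder(z)).view(batch_size, 1, 28, 28)

# Reconstruction term (BCE assumed here), averaged over the batch
recon_loss = F.binary_cross_entropy(recon, images, reduction="sum") / batch_size

# KL(q || N(0, I)): summed over latent dims, averaged over the batch
prior = Normal(torch.zeros_like(mean), torch.ones_like(std))
kl = kl_divergence(q_dist, prior).sum(-1).mean()

# Closed form per latent dim for N(mean, std^2) vs. N(0, 1), as a cross-check:
# KL = 0.5 * (std^2 + mean^2 - 1) - log(std)
kl_closed = (0.5 * (std ** 2 + mean ** 2 - 1) - log_std).sum(-1).mean()

loss = recon_loss + kl
loss.backward()  # gradients flow into both halves of encoding
print(recon_loss.item(), kl.item(), kl_closed.item())
```

Both terms treat the first half of `encoding` as means and the exponentiated second half as stddevs, which is the only place those roles are defined.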


Thanks for the additional explanation!! :slight_smile: