C4_W3_Lab_1_VAE_MNIST: Question about the binary cross-entropy loss in the training loop

Hey there,

Maybe I am wrong here, so please correct me if so. This is about the C4_W3_Lab_1_VAE_MNIST lab.

According to the docs, we can use add_loss when the loss depends on the model's inputs, right?
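If I understand that correctly, the point is that such a loss cannot be expressed through compile(loss=...), so it is attached to the model instead. A minimal sketch of the pattern (hypothetical example, not from the lab):

import tensorflow as tf

# Hypothetical model whose extra loss term depends on its inputs.
inputs = tf.keras.layers.Input(shape=(784,))
outputs = tf.keras.layers.Dense(784, activation='sigmoid')(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# The loss tensor is built from `inputs`, so compile(loss=...) cannot
# express it; add_loss attaches it to the model, and it shows up later
# in model.losses.
model.add_loss(tf.reduce_mean(tf.square(inputs - outputs)))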

In the training loop section, the original code is:


      # compute reconstruction loss
      flattened_inputs = tf.reshape(x_batch_train, shape=[-1])
      flattened_outputs = tf.reshape(reconstructed, shape=[-1])
      loss = bce_loss(flattened_inputs, flattened_outputs) * 784

which shows that the binary cross-entropy loss is indeed computed from the inputs and the final decoded outputs, after flattening them.
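For context, I assume bce_loss here is the plain Keras BCE object defined earlier in the lab, and the factor 784 (= 28 * 28) rescales the per-pixel mean back to a per-image total:

# Assumed from earlier in the lab: a plain BinaryCrossentropy loss object.
bce_loss = tf.keras.losses.BinaryCrossentropy()

# bce_loss(y_true, y_pred) returns the mean BCE over all flattened pixels,
# so multiplying by 784 (= 28 * 28) scales it up to a per-image total.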

So my idea is to move this computation into the vae_model function:


def vae_model(encoder, decoder, input_shape):
  inputs = tf.keras.layers.Input(shape=input_shape)
  mu, sigma, z = encoder(inputs)
  reconstructed = decoder(z)

  # flatten both the inputs and the decoded outputs for the BCE loss
  inputs_flatten = tf.keras.layers.Flatten()(inputs)
  reconstructed_flatten = tf.keras.layers.Flatten()(reconstructed)

  model = tf.keras.Model(inputs=inputs, outputs=reconstructed)

  # KL divergence loss, attached via add_loss as in the original lab code
  kl_loss = kl_reconstruction_loss(inputs, z, mu, sigma)
  model.add_loss(kl_loss)

  # NEW: add the binary cross-entropy loss between inputs and decoded outputs
  bce = tf.keras.losses.binary_crossentropy(inputs_flatten, reconstructed_flatten) * 784
  model.add_loss(bce)

  return model
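For reference, kl_reconstruction_loss is the helper already defined in the lab; from memory it looks roughly like this (the exact signature may differ, and inputs/outputs are passed but not actually used in the KL term):

def kl_reconstruction_loss(inputs, outputs, mu, sigma):
  """Kullback-Leibler divergence term of the VAE loss (sketch from memory)."""
  kl_loss = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)
  return tf.reduce_mean(kl_loss) * -0.5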

and then the training loop can be simplified to:


    with tf.GradientTape() as tape:

      # feed a batch to the VAE model
      vae(x_batch_train)
      
      # sum the losses attached in vae_model (KL divergence + reconstruction)
      loss = sum(vae.losses)

    # get the gradients and update the weights
    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))

    # compute the loss metric
    loss_metric(loss)
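(The names vae, optimizer and loss_metric come from the earlier setup cells in the lab; I assume they are created roughly like this:)

# assumed setup, matching the lab as far as I remember
vae = vae_model(encoder, decoder, input_shape=(28, 28, 1))
optimizer = tf.keras.optimizers.Adam()
loss_metric = tf.keras.metrics.Mean()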

Would this be a problem compared with the original solution? The training loop runs the same way as the original one, but I am not sure if I am on the right track.

Thanks :man_shrugging:t2::man_shrugging:t2::man_shrugging:t2::man_shrugging:t2::man_shrugging:t2:

I still need to run the model with this modification to see whether it behaves the same.