update_state is called in the screencast but not in the lab

Inside the train_data_for_one_epoch function, update_state is not called on the metric object in the lab, but it is called in the "Gradients, metrics, and validation" video.

In the lab, the following code is used:

def train_data_for_one_epoch():
  losses = []
  pbar = tqdm(total=len(list(enumerate(train))), position=0, leave=True, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} ')
  for step, (x_batch_train, y_batch_train) in enumerate(train):
      logits, loss_value = apply_gradient(optimizer, model, x_batch_train, y_batch_train)
      
      losses.append(loss_value)
      
      train_acc_metric(y_batch_train, logits)
      pbar.set_description("Training loss for step %s: %.4f" % (int(step), float(loss_value)))
      pbar.update()
  return losses

In the "Gradients, metrics, and validation" video, however, the code is:

def train_data_for_one_epoch():
  losses = []
  for step, (x_batch_train, y_batch_train) in enumerate(train):
      logits, loss_value = apply_gradient(optimizer, model, x_batch_train, y_batch_train)
      
      losses.append(loss_value)
      
      train_acc_metric.update_state(y_batch_train, logits)
  return losses

How is the metric object getting updated in the lab?

Hi there,

update_state is a method of the Metric base class, and each subclass (in our case SparseCategoricalAccuracy, etc.) implements it. Calling update_state accumulates the batch statistics in the metric object's state. Calling the metric object itself, as the lab does with train_acc_metric(y_batch_train, logits), updates the state as well, because Metric.__call__ runs update_state and then returns the current result(). So the only difference is calling the method directly versus calling the object. Have a look here: Update_state method
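To make this concrete, here is a minimal pure-Python sketch of the pattern. The classes Metric and ToySparseCategoricalAccuracy below are simplified, hypothetical stand-ins written for illustration, not the actual Keras source; they only assume that Keras's Metric.__call__ calls update_state and then returns result(). The sketch runs without TensorFlow.

```python
class Metric:
    """Simplified stand-in for the Keras Metric base class (illustration only)."""

    def __call__(self, *args, **kwargs):
        # Calling the object accumulates state first, then returns the running
        # result. This mirrors the behavior of Keras's Metric.__call__.
        self.update_state(*args, **kwargs)
        return self.result()

    def update_state(self, y_true, y_pred):
        raise NotImplementedError

    def result(self):
        raise NotImplementedError


class ToySparseCategoricalAccuracy(Metric):
    """Hypothetical toy metric: y_true holds integer labels,
    y_pred holds one list of per-class scores (logits) per example."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred):
        for label, scores in zip(y_true, y_pred):
            # Predicted class is the index of the largest score (argmax).
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self.correct += int(predicted == label)
            self.total += 1

    def result(self):
        return self.correct / self.total if self.total else 0.0


y_batch = [0, 1, 2, 1]
logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

explicit = ToySparseCategoricalAccuracy()
explicit.update_state(y_batch, logits)  # style used in the video

implicit = ToySparseCategoricalAccuracy()
implicit(y_batch, logits)               # style used in the lab

print(explicit.result(), implicit.result())  # 0.75 0.75
```

Both objects end up with the same accumulated state, which mirrors why the lab's train_acc_metric(y_batch_train, logits) and the video's train_acc_metric.update_state(y_batch_train, logits) behave the same inside the training loop. The call form additionally returns the running accuracy, which the lab simply discards.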


It might be worth adding this information as a note in the lab.
