Inside the train_data_for_one_epoch function, update_state is not called on the metric object in the lab, but it is called in the "Gradients, metrics, and validation" video.
In the lab, the following code is used:
def train_data_for_one_epoch():
    losses = []
    pbar = tqdm(total=len(list(enumerate(train))), position=0, leave=True, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} ')
    for step, (x_batch_train, y_batch_train) in enumerate(train):
        logits, loss_value = apply_gradient(optimizer, model, x_batch_train, y_batch_train)
        losses.append(loss_value)
        train_acc_metric(y_batch_train, logits)
        pbar.set_description("Training loss for step %s: %.4f" % (int(step), float(loss_value)))
        pbar.update()
    return losses
while in the "Gradients, metrics, and validation" video, this is the code:
def train_data_for_one_epoch():
    losses = []
    for step, (x_batch_train, y_batch_train) in enumerate(train):
        logits, loss_value = apply_gradient(optimizer, model, x_batch_train, y_batch_train)
        losses.append(loss_value)
        train_acc_metric.update_state(y_batch_train, logits)
    return losses
How is the metric object getting updated in the lab?
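For context on why both snippets behave the same: Keras metric objects are callable, and calling a metric invokes its __call__ method, which internally runs update_state on the new batch and then returns the current result. So train_acc_metric(y_batch_train, logits) updates the metric just as train_acc_metric.update_state(y_batch_train, logits) does. The following plain-Python sketch (a hypothetical stand-in class, not the actual Keras source) illustrates the pattern:

```python
class SketchAccuracyMetric:
    """Hypothetical stand-in mimicking how a Keras metric accumulates state."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred):
        # Accumulate running counts across batches, like Keras update_state.
        for t, p in zip(y_true, y_pred):
            self.correct += int(t == p)
            self.total += 1

    def result(self):
        # Current accuracy over everything seen so far.
        return self.correct / self.total if self.total else 0.0

    def __call__(self, y_true, y_pred):
        # Calling the object delegates to update_state, then returns the result,
        # which is (in sketch form) what Keras metrics do when called directly.
        self.update_state(y_true, y_pred)
        return self.result()


metric = SketchAccuracyMetric()
metric([1, 0, 1], [1, 1, 1])  # same effect as metric.update_state([1, 0, 1], [1, 1, 1])
print(metric.result())        # 2 of 3 predictions correct
```

In other words, the lab's shorter spelling and the video's explicit update_state call are equivalent; the lab simply relies on the callable form.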