Week 2: Implementing Callbacks in TensorFlow using the MNIST Dataset

Hi,

I am trying to implement a solution to the Week 2 assignment. I have successfully implemented the callback function and have instantiated my callback class.

Here is the code for my callback function:

class myCallback(keras.callbacks.Callback):
        # Define the correct function signature for on_epoch_end
        def on_epoch_end(self, epoch, logs={}):
            if(logs.get('acc')>0.99):
                print("\nReached 99% accuracy so cancelling training!")
                self.model.stop_training = True


After that, I instantiate the callback and define the model:

> def train_mnist(x_train, y_train):
> 
>     ### START CODE HERE
> 
>     # Instantiate the callback class
>     callbacks = myCallback()
> 
>     # Define the model
>     model = tf.keras.models.Sequential([
>         # YOUR CODE STARTS HERE
>         tf.keras.layers.Flatten(input_shape=(28, 28)),
>         tf.keras.layers.Dense(512, activation=tf.nn.relu),
>         tf.keras.layers.Dense(10, activation=tf.nn.softmax)
>         # YOUR CODE ENDS HERE
>     ])
> 
>     # Compile the model
>     model.compile(optimizer='adam',
>                   loss='sparse_categorical_crossentropy',
>                   metrics=['accuracy'])
> 
>     # Fit the model for 10 epochs adding the callbacks
>     # and save the training history
>     history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
> 
>     ### END CODE HERE
> 
>     return history
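For context, x_train and y_train are prepared earlier in the assignment notebook. A typical loading step (shown here only as my assumption of what the notebook does, not code I wrote) would be:

    import tensorflow as tf

    # Load MNIST and scale pixel values to the [0, 1] range
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train / 255.0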


But when I run the following command:

> hist = train_mnist(x_train, y_train)

The following error is produced:

> ---------------------------------------------------------------------------
> TypeError                                 Traceback (most recent call last)
> <ipython-input-44-669c19a8f225> in <module>
>       1 # grader-required-cell
>       2 
> ----> 3 hist = train_mnist(x_train, y_train)
> 
> <ipython-input-43-9f98fb227b18> in train_mnist(x_train, y_train)
>      31     # Fit the model for 10 epochs adding the callbacks
>      32     # and save the training history
> ---> 33     history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
>      34 
>      35     ### END CODE HERE
> 
> /opt/conda/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
>      65     except Exception as e:  # pylint: disable=broad-except
>      66       filtered_tb = _process_traceback_frames(e.__traceback__)
> ---> 67       raise e.with_traceback(filtered_tb) from None
>      68     finally:
>      69       del filtered_tb
> 
> <ipython-input-38-18d6a88f32e2> in on_epoch_end(self, epoch, logs)
>      10         # Define the correct function signature for on_epoch_end
>      11         def on_epoch_end(self, epoch, logs={}):
> ---> 12             if(logs.get('acc')>0.99):
>      13                 print("\nReached 99% accuracy so cancelling training!")
>      14                 self.model.stop_training = True
> 
> TypeError: '>' not supported between instances of 'NoneType' and 'float'

What happens if you use the same key everywhere, versus accuracy in model.compile but acc in the callback?

@ai_curious

I don't understand.

That is a dictionary lookup on the key acc.

But you compiled the model using the metric name accuracy.
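Concretely, here is a minimal sketch of the mismatch; the keys below are an assumption about what Keras passes to on_epoch_end when the model is compiled with metrics=['accuracy']:

    # What the logs dict typically looks like at the end of an epoch
    logs = {'loss': 0.08, 'accuracy': 0.975}

    print(logs.get('acc'))       # None -- there is no 'acc' key
    print(logs.get('accuracy'))  # 0.975
    # logs.get('acc') > 0.99     # would raise: TypeError: '>' not supported
                                 # between instances of 'NoneType' and 'float'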

I didn't write a test harness myself, but if I had, I would check what happens when they are identical.
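A minimal sketch of what "identical" could look like, assuming the model keeps metrics=['accuracy'] in model.compile:

    import tensorflow as tf

    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            # Use the same key that was named in model.compile(metrics=['accuracy'])
            acc = logs.get('accuracy')
            if acc is not None and acc >= 0.99:
                print("\nReached 99% accuracy so cancelling training!")
                self.model.stop_training = True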


Hi,

I edited the callback function as shown:

> class myCallback(tf.keras.callbacks.Callback):
>         # Define the correct function signature for on_epoch_end
>         def on_epoch_end(self, epoch, logs={}):
>             if(logs.get('acc') is not None and logs.get('acc') >= 0.99):
>                 print("\nReached 99% accuracy so cancelling training!")
>                 self.model.stop_training = True

This solves the previous error.

OK, but maybe that just keeps the rule from ever firing, because logs.get('acc') is always None. Are you ever seeing the early stopping message?
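One quick way to check is to print the keys Keras actually reports; this is just a debugging sketch, not part of the assignment:

    import tensorflow as tf

    class DebugCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            # Shows the exact metric names available, e.g. ['loss', 'accuracy']
            print(list((logs or {}).keys()))

If 'acc' never appears in that list, the is not None guard only hides the problem instead of fixing it.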

I think I have my answer…