Plenty of people have trouble with the syntax of Keras model creation. What’s up with the square brackets? The commas? That weird (x) thing at the end? Without going too deep into a full TensorFlow tutorial, I thought I’d share a few things that might help simplify what’s going on.
First, it’s important to remember that TensorFlow is based on a graph model. Even though eager execution now allows operations to be evaluated immediately, without building graphs, the graph influence is pervasive.
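For example, with eager execution an operation runs as soon as it’s called and returns a concrete value. A minimal sketch (the tensor values are arbitrary):
#eager execution (the default in TF 2.x): ops evaluate immediately
import tensorflow as tf
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
print(a + b)  #tf.Tensor([4. 6.], shape=(2,), dtype=float32)
I find the graph influence particularly noticeable when creating Keras models. Think of the layers as nodes in the graph, and the data flow between layers as the graph edges. Hopefully this makes more sense below.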
Let’s start with the Sequential model. According to the documentation, a Sequential model is appropriate for a plain stack of layers where each layer has exactly one input tensor and one output tensor. Here’s how you set up a Sequential model:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
#define a model architecture using Sequential
sequential_model_def = keras.Sequential(
    [
        layers.Dense(64, activation="relu", name="dense_1_sequential_def"),
        layers.Dense(64, activation="relu", name="dense_2_sequential_def"),
        layers.Dense(10, name="predictions_sequential_def"),
    ]
)
A simple model with 3 fully connected layers. Let’s see what it looks like as a graph:
keras.utils.plot_model(sequential_model_def, "visualize_sequential_model_no_input_yet.png", show_shapes=True)
Hmm, not very interesting. What’s going on? It turns out that the Sequential model constructor does a lot of work for you under the covers. One of those tasks is dynamically building the weight matrices. But weight matrix shapes depend on the shape of the input, which we haven’t provided yet. As a result, the model isn’t considered ‘built’ until you provide it with the input shape.
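You can verify this directly; the built flag stays False until the model knows its input shape:
#the layers exist, but no weights have been created yet
print(sequential_model_def.built)  #False
#calling sequential_model_def.summary() at this point would raise a ValueError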
#define an input layer that can be used for all examples
input_layer = keras.Input(shape=(784,), name="input_layer")
#inform the model about its inputs
sequential_model_def(input_layer)
#now try visualizing again
keras.utils.plot_model(sequential_model_def, "visualize_sequential_model.png", show_shapes=True)
Take a moment to confirm that the input and output shapes of the layers match the definition in the code above. Also notice that the directed graph was built to match the order of the comma-delimited list of layers. This is where the square brackets come into play. You are passing the model constructor a Python list that is defined on the fly using square brackets, separating items with commas: [a, b, c]. The dependencies (data flow/graph edges) between the model layers (graph nodes) are automatically inferred from the order in which they appear in that list.
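Since it really is just an ordinary Python list, you can also build the list first and hand it to the constructor by name. A quick sketch (layer_stack and the alternate layer names are just illustrative):
#the square brackets define an ordinary Python list of layers
layer_stack = [
    layers.Dense(64, activation="relu", name="dense_1_list_def"),
    layers.Dense(64, activation="relu", name="dense_2_list_def"),
    layers.Dense(10, name="predictions_list_def"),
]
#passing the list variable is equivalent to writing the brackets inline
sequential_model_def_alt = keras.Sequential(layer_stack)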
There is an alternative method for conveying input shape to your model. You can pass it as an optional argument to the first layer. That is:
layers.Dense(64, activation="relu",name="dense_1_sequential_def", input_shape=(784,)),
NOTE: what you cannot do is provide the input shape like this: data_format=(784,). The data_format parameter is a string, not a tuple; the only acceptable values are channels_last (the default) and channels_first.
You can accomplish the same thing using the Sequential API with a different style, namely the add() function. Here is the code and the graph it produces:
#what about using the add() function?
sequential_model_def_2 = keras.Sequential()
sequential_model_def_2.add(layers.Dense(64, activation="relu", name="dense_1_sequential_def_2"))
sequential_model_def_2.add(layers.Dense(64, activation="relu", name="dense_2_sequential_def_2"))
sequential_model_def_2.add(layers.Dense(10, name="predictions_sequential_def_2"))
sequential_model_def_2(input_layer)
keras.utils.plot_model(sequential_model_def_2, "visualize_sequential_model_2.png", show_shapes=True)
I changed the layer names to use a def_2 suffix, but otherwise it’s exactly the same. The add() function acts like an append(), in that the new layer appears at the end of the list (NOTE: IIRC the internal data structure is actually a stack). You can see why the Sequential model only supports simple models: it is a limitation of using a Python list to define the topology.
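Consistent with that append/stack behavior, the Sequential API also provides a pop() method that removes the most recently added layer. A quick illustration using the model above:
#pop() removes the last layer, like popping a stack
sequential_model_def_2.pop()
print(len(sequential_model_def_2.layers))  #2 -- the predictions layer is gone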
To overcome this limitation there is the Functional API. The Keras Functional API is a way to create more flexible (i.e., more complex) models. It can handle non-linear topology, shared layers, and multiple inputs and outputs. Unlike the Sequential model, where the simple topology can be inferred from the Python list of layers, in the Functional API you explicitly build the graph of layer nodes. Here is an example, using the same input layer and the same fully connected shapes:
#now let's define the same architecture using the functional API
l1 = layers.Dense(64, activation="relu", name="dense_1_functional_def")(input_layer)
l2 = layers.Dense(64, activation="relu", name="dense_2_functional_def")(l1)
predictions = layers.Dense(10, name="predictions_functional_def")(l2)
functional_model_def = keras.Model(inputs=input_layer, outputs=predictions)
keras.utils.plot_model(functional_model_def, "visualize_functional_model.png", show_shapes=True)
Again, other than the labels, the graph is the same. The layer shapes are provided the same way in both APIs. What is different here is that the layer dependencies are explicitly added inside the trailing parentheses. That is, the l1 layer is connected to the input_layer by writing…
l1 = layers.Dense(...)(input_layer)
The l2 layer is connected to the l1 layer by writing…
l2 = layers.Dense(...)(l1)
and so on. The first part of the expression is the constructor; it builds the new layer instance node. The second part, the trailing (x), connects that new node to another layer.
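If the constructor-plus-call syntax still looks odd, it can help to split it into two statements (a sketch; the names are just illustrative):
#step 1: the constructor builds a new layer instance (a graph node)
dense_layer = layers.Dense(64, activation="relu", name="dense_1_two_step")
#step 2: calling the instance on a tensor wires it into the graph (an edge)
l1_alt = dense_layer(input_layer)
That explicit wiring is also what lets the Functional API express topologies the Sequential list never could. As a final sketch (the branch names are hypothetical), here is a model with two parallel branches off the same input, something a flat list of layers cannot describe:
#two parallel branches off the same input layer
branch_a = layers.Dense(32, activation="relu", name="branch_a")(input_layer)
branch_b = layers.Dense(32, activation="relu", name="branch_b")(input_layer)
#merge the branches back into a single tensor
merged = layers.concatenate([branch_a, branch_b], name="merge_branches")
branched_predictions = layers.Dense(10, name="predictions_branched")(merged)
branched_model = keras.Model(inputs=input_layer, outputs=branched_predictions)
keras.utils.plot_model(branched_model, "visualize_branched_model.png", show_shapes=True)
Once a model is built, though, there is no difference in the way the two APIs’ models are handled or treated. It’s only the syntax for model definition that varies. Hopefully this helps make sense of why and how to deal with the differences.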