C2_W2: ReLU activations as curve approximation

Lab link: Coursera

This lab suggests an interesting way to conceptualise ReLU layers: as approximating a complex curve with multiple piecewise-linear segments, each node corresponding to one segment:

My hunch is that this interpretation is very useful in some cases and not at all useful in others. So I was interested to explore further with some experiments.

In a Google Colab notebook, I trained a simple model on the Fashion-MNIST dataset, which contains 60,000 28x28 greyscale images of various fashion items. I then attempted to visualise the “curve” of the first layer.

However, despite the layer in question having 128 neurons, I still only got very simple curves. Here’s a sample of curves plotted against 36 different input features:

Google Colab notebook here - see the later sections starting from “Multi-plot”:

In this case, the model is structured as follows:

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    # ... (subsequent layers omitted)
])

I’m looking at the activations of the first Dense layer.
It has 28x28 = 784 inputs and 128 outputs.

To interpret something as a curve, I’ve focused on a single input pixel as the feature of interest, varying it from 0.0 to 1.0 as the x-axis of my curve plot. I simulate 100 images with that pixel swept across that range, holding all other pixels fixed at their means.
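The probing procedure above can be sketched in numpy. The weight matrix here is randomly initialized as a stand-in for the trained layer’s parameters (with the real Keras model they could be read via `model.layers[1].get_weights()`), and the constant `mean_image` is a placeholder for the true training-set mean:

```python
import numpy as np

rng = np.random.default_rng(0)

n_pixels = 28 * 28            # 784 flattened inputs
n_units = 128                 # width of the first Dense layer
pixel_of_interest = 300       # arbitrary example pixel index

# Stand-ins for the trained layer's parameters; with the real model
# these could be read via W, b = model.layers[1].get_weights().
W = rng.normal(scale=0.05, size=(n_pixels, n_units))
b = np.zeros(n_units)

# Placeholder for the training-set mean image ("all other pixels at their means").
mean_image = np.full(n_pixels, 0.3)

# 100 probe images in which only one pixel sweeps from 0.0 to 1.0.
sweep = np.linspace(0.0, 1.0, 100)
probes = np.tile(mean_image, (100, 1))
probes[:, pixel_of_interest] = sweep

# First-layer activations: relu(x @ W + b) -> one 100-point "curve" per unit.
activations = np.maximum(0.0, probes @ W + b)
print(activations.shape)  # (100, 128)
```

Plotting any one column of `activations` against `sweep` gives one unit’s “curve” for that pixel.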

Why do this? I’m treating those 128 neurons as the generators of the “segments” in the complex curve that I’m assuming it’s approximating. I want to vary one input feature and see what curve the layer generates against that feature.

The rest of the notebook was all about trying to troubleshoot and to find which pixels would generate interesting curves.

So here are my questions to the community: Why are these curves so boring? With 128 neurons, I want to see a really complex curve. Why don’t we get that? What other problem domains are better suited to the “curve” interpretation, and would produce more interesting curves?

I’d like to offer a partial answer to one of my questions. I think the curves are boring because these 128 neurons are not modelling just one curve, but 784 curves (or, more accurately, a single curve in 784-dimensional space). That implies two things. Firstly, gradient descent will have found a sort of “middle ground”: a solution that works in an averaged world modelling all 784 dimensions. It’s as if each individual neuron has to act like a single cell in a hologram, catering to multiple representations depending on the angle you view it from. Secondly, for visualisation purposes I’ve dropped the scales, but they’re probably incredibly small. When you vary just one feature out of 784, you’re not going to get a strong impact.
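A rough back-of-envelope check of that second point (again using random weights as stand-ins for the trained ones): for any unit, the most a single pixel sweep can move the pre-activation is that pixel’s weight times 1.0, which is tiny compared with the combined contribution of the other 783 pixels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in random weights for a 784-input, 128-unit Dense layer.
W = rng.normal(scale=0.05, size=(784, 128))
mean_image = np.full(784, 0.3)

# Typical total pre-activation magnitude with every pixel at its mean...
typical_total = np.abs(mean_image @ W).mean()

# ...versus the largest change a single pixel sweep (0.0 -> 1.0) can cause,
# which is |w_ij| * 1.0 for that pixel's weight.
per_pixel_swing = np.abs(W).mean()

print(per_pixel_swing < typical_total)  # True: one pixel barely moves the needle
```

With these stand-in numbers the single-pixel swing is roughly an order of magnitude smaller than the typical total, which is consistent with the flat-looking curves.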

So, what other kinds of problem domains and models would produce more interesting curves?


I have done an experiment on simple 2D curve fitting using ReLU, because I had the same question that you did. How does it actually work?

Answer: The “piecewise fit” analogy works well for simple curves; that’s exactly what happens. One has to realize that each ReLU unit also includes a weight and a bias value, and these values can be either positive or negative.
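As a concrete sketch of that piecewise fit: a handful of ReLU units, each with its own weight and bias (positive or negative), sum into a piecewise-linear approximation of a smooth curve. The knot positions and slopes below are hand-picked for illustration, not fitted values:

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

x = np.linspace(0.0, np.pi, 200)
target = np.sin(x)

# Three hand-placed hinge points (knots) at 0, pi/3 and 2*pi/3;
# each ReLU unit changes the overall slope at its knot.
k = np.pi / 3
s = np.sin(k) / k            # slope of the first chord, ~0.827

# Unit 1 rises from x=0; units 2 and 3 (negative weights) bend the
# curve flat and then downward at their knots.
approx = s * relu(x) - s * relu(x - k) - s * relu(x - 2 * k)

max_err = np.max(np.abs(approx - target))
print(round(max_err, 3))  # 0.134
```

Adding more units (more knots) shrinks the error further; the negative-weight units are what let the curve bend back down.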

I have not tried the same experiment on an image recognition task, because the input data (the pixels of an image converted into a vector) is not very well behaved, so the solution space is a lot more complicated than a simple 2D curve.

Typically you need a lot more ReLU units to accomplish the same fit as a smooth non-linear activation like tanh() or sigmoid().

It’s a very good question.

In a nutshell, I think the issue is that a complex image can’t be immediately reduced to a collection of lines by a single ReLU layer.

If you look at an image of a cat, and you want to identify that it’s a cat, what “lines” need to be learned? What collection of lines would identify a cat reliably? That seems like more than a single ReLU layer can manage.

Cool. I suppose an interesting exercise would be to train a simple network to model a complex 1.5D curve (a single input mapped to a single output) over a fixed range. With enough ReLU nodes, it should be able to do that quite accurately.
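A quick sketch of that exercise, using a random-feature least-squares fit in numpy as a cheap stand-in for gradient-descent training (so it runs without TensorFlow): 128 ReLU units with fixed random weights and biases, with only the output weights fitted.

```python
import numpy as np

rng = np.random.default_rng(1)

# A fairly wiggly 1D target curve over a fixed range.
x = np.linspace(-1.0, 1.0, 400).reshape(-1, 1)
y = np.sin(4 * x) + 0.3 * np.cos(9 * x)

# 128 ReLU units with random, fixed input weights and biases; each unit
# contributes one hinge ("knot") to the piecewise-linear fit.
n_units = 128
W = rng.uniform(-8.0, 8.0, size=(1, n_units))
b = rng.uniform(-8.0, 8.0, size=n_units)
H = np.maximum(0.0, x @ W + b)          # (400, 128) piecewise-linear features

# Fit only the output weights by least squares, as a stand-in for
# gradient-descent training of the whole network.
coef, *_ = np.linalg.lstsq(H, y, rcond=None)
y_hat = H @ coef

rmse = float(np.sqrt(np.mean((y - y_hat) ** 2)))
print(rmse < 0.1)  # True: the piecewise-linear fit tracks the curve closely
```

A real trained network would also move the hinge locations during training, so it should do at least this well with the same number of units.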

The other lesson, I think, from the images use case is that the same number of ReLU nodes (say, 128) can learn multiple curves (hologram-style), but each only with much lower fidelity.


You can learn a lot by experimenting with different hidden layer activations, and different sizes of hidden layers.