Output layer of BERT

BERT pretraining combines the losses of masked token prediction and next sentence prediction. However, the model shown in the video only shows the tokens in the output layer. There are no outputs for next sentence prediction. Hence, there does not seem to be a way to calculate the binary cross-entropy loss for next sentence prediction.

Or is the next sentence target value also just a token in the input that is always masked?

Hi @Ritu_Pande

I’m not sure I fully understand your question.

From the Appendix of [the paper]:

Next Sentence Prediction: The next sentence prediction task can be illustrated in the following examples.

Input = [CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]
Label = IsNext

Input = [CLS] the man [MASK] to the store [SEP] penguin [MASK] are flight ##less birds [SEP]
Label = NotNext
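To make the example concrete, here is a minimal Python sketch of how such a sentence pair could be packed into a single input. It assumes the sentences are already tokenized, and it produces the segment ids (0 for [CLS], sentence A and its [SEP]; 1 for sentence B and its [SEP]) that BERT uses to tell the two sentences apart. The label constants are a hypothetical encoding, not taken from the paper.

```python
from typing import List, Tuple

# Hypothetical label encoding for the NSP task
IS_NEXT, NOT_NEXT = 0, 1

def make_nsp_input(sent_a: List[str], sent_b: List[str]) -> Tuple[List[str], List[int]]:
    """Pack two tokenized sentences as [CLS] A [SEP] B [SEP] with segment ids."""
    tokens = ["[CLS]"] + sent_a + ["[SEP]"] + sent_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP].
    segment_ids = [0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1)
    return tokens, segment_ids

tokens, segments = make_nsp_input(
    "the man went to [MASK] store".split(),
    "he bought a gallon [MASK] milk".split(),
)
label = IS_NEXT  # sentence B really follows sentence A in this pair
print(" ".join(tokens))
print(segments)
```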

So the “C” output (the final vector at the position of the [CLS] input token) should be used to assign a high probability when sentence B naturally follows sentence A, and a low probability when the two sentences do not belong together.

Did I misunderstand something in your question?

The number of units in the last Dense layer for every token is V (am I correct?). When predicting tokens, we take the argmax over all the vocabulary tokens. But it is not clear to me how the sentence classification value is derived from those V values.

Maybe I am missing something basic here, but I am not sure how the last output layer is designed.

That is not quite right. During training we do not take the argmax; instead we check the probability the model assigns to the target word at the [MASK] position: if that probability is high, the loss is low; if it is low, the loss is high. In other words, we don’t care what the argmax is (it may well be another word); the loss is computed only from the probability of the target word (the word that was replaced by [MASK]) at that position.
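A minimal sketch of that idea in plain Python, assuming a toy vocabulary of 4 words and made-up logits: the loss at a [MASK] position is the cross-entropy -log p(target), which is low when the target's probability is high, regardless of which word happens to be the argmax.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def masked_token_loss(logits: List[float], target_id: int) -> float:
    """Cross-entropy at one [MASK] position: -log p(target word)."""
    return -math.log(softmax(logits)[target_id])

# Toy vocabulary of size V = 4 (hypothetical); the true word has id 2.
logits = [2.0, 0.5, 1.5, -1.0]
target = 2
print(max(range(len(logits)), key=logits.__getitem__))  # argmax is id 0, not the target
print(masked_token_loss(logits, target))  # loss only looks at p(id 2)
print(masked_token_loss(logits, 0))       # a more probable target gives a lower loss
```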

There are variations of “the last output layer”, but in general, for next sentence prediction the “C” output (at the [CLS] position) is transformed into a 2-element vector (one logit each for IsNext and NotNext) by a simple classification layer. For example, if the “C” output is a 768-long vector (as in BERT Base), you can multiply it by a 768x2 matrix to get the IsNext/NotNext prediction and update the weights accordingly (you can do this because autograd takes care of tracking the gradients).
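Here is a rough sketch of such an NSP head in plain Python, with randomly initialized (hypothetical) weights and a random stand-in for the “C” vector. It treats the binary IsNext/NotNext decision as a 2-way softmax cross-entropy, which is equivalent to a binary cross-entropy on the difference of the two logits. The reference BERT implementation first passes C through a small “pooler” (a dense layer with tanh) before this projection; that step is omitted here.

```python
import math
import random
from typing import List

HIDDEN = 768  # hidden size of BERT Base
random.seed(0)

# Hypothetical NSP head: a 768x2 weight matrix plus a bias of length 2
W = [[random.gauss(0.0, 0.02) for _ in range(2)] for _ in range(HIDDEN)]
b = [0.0, 0.0]

def nsp_logits(c: List[float]) -> List[float]:
    """Project the 768-long "C" vector to 2 logits: [IsNext, NotNext]."""
    return [sum(c[i] * W[i][j] for i in range(HIDDEN)) + b[j] for j in range(2)]

def nsp_loss(logits: List[float], label: int) -> float:
    """2-way softmax cross-entropy: logsumexp(logits) - logits[label]."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[label]

c = [random.gauss(0.0, 1.0) for _ in range(HIDDEN)]  # stand-in for BERT's "C" output
logits = nsp_logits(c)
print(logits, nsp_loss(logits, label=0))
```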

From the paper: “The overall training loss is the sum of the mean masked LM likelihood and the mean next sentence prediction likelihood.”
In simpler terms, when training BERT, the Masked Language Model (MLM) and Next Sentence Prediction (NSP) tasks are trained together with a combined loss function (autograd takes care of tracking which weights influenced which predictions, both MLM and NSP).
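In code, the combination is just a sum of two means. A minimal sketch, assuming the per-position MLM losses and per-pair NSP losses have already been computed (the numbers below are made up); “likelihood” in the quote is meant as negative log-likelihood, i.e. the cross-entropy losses being minimized.

```python
from typing import List

def bert_pretraining_loss(mlm_losses: List[float], nsp_losses: List[float]) -> float:
    """Sum of the mean masked-LM loss and the mean NSP loss."""
    mean_mlm = sum(mlm_losses) / len(mlm_losses)  # averaged over masked positions
    mean_nsp = sum(nsp_losses) / len(nsp_losses)  # averaged over sentence pairs
    return mean_mlm + mean_nsp

# Hypothetical per-example losses: three masked positions, two sentence pairs
total = bert_pretraining_loss([1.2, 0.8, 2.0], [0.3, 0.9])
print(total)
```

Because both terms end up in one scalar, a single backward pass sends gradients to the shared encoder from both tasks.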

Do you have a diagram of the model architecture that shows an example of the last output layer of BERT? Not just for the CLS token but for all tokens in the output.

What do you mean by:


In your posted image, the green rectangles are “all tokens in the output”. Every green rectangle is a 768-long vector (for the default BERT Base model).
And, as I mentioned, for next sentence prediction during training, the “C” rectangle is passed through the next sentence prediction head (a simple Linear layer) to get the IsNext or NotNext prediction.
Also, the T_i outputs at the [MASK] positions are pulled out and passed through the masked LM head, which projects each of them to V vocabulary logits, to check whether they correctly predict the original word at that position.
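The whole output side can be sketched in plain Python with toy sizes (HIDDEN = 8 and VOCAB = 20 instead of 768 and BERT's vocabulary of roughly 30k tokens) and random stand-ins for the encoder outputs and head weights. It shows that every token gets a HIDDEN-long output vector, that only the first one (“C”) feeds the 2-way NSP head, and that only the [MASK] positions are projected to vocabulary logits. The reference implementation's MLM head also applies a small transform and reuses the input embedding matrix as its output weights; that is omitted here.

```python
import random
from typing import List

random.seed(0)
HIDDEN, VOCAB = 8, 20  # toy sizes; BERT Base uses 768 and a ~30k-token vocabulary

def linear(x: List[float], W: List[List[float]]) -> List[float]:
    """x: [in], W: [in][out] -> [out] (no bias, for brevity)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def rand_matrix(rows: int, cols: int) -> List[List[float]]:
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

tokens = ["[CLS]", "the", "man", "[MASK]", "to", "the", "store", "[SEP]"]
# Stand-in for the encoder's final hidden states: one HIDDEN-long vector per token
sequence_output = [[random.gauss(0.0, 1.0) for _ in range(HIDDEN)] for _ in tokens]

nsp_W = rand_matrix(HIDDEN, 2)      # NSP head (hypothetical weights)
mlm_W = rand_matrix(HIDDEN, VOCAB)  # MLM head (hypothetical weights)

# NSP: only the first output vector ("C") is used
nsp_logits = linear(sequence_output[0], nsp_W)

# MLM: only the outputs at [MASK] positions are projected to the vocabulary
masked_positions = [i for i, t in enumerate(tokens) if t == "[MASK]"]
mlm_logits = [linear(sequence_output[i], mlm_W) for i in masked_positions]

print(len(sequence_output), len(sequence_output[0]))  # 8 8
print(len(nsp_logits))                                 # 2
print(masked_positions, len(mlm_logits[0]))            # [3] 20
```

This is why the sentence classification is not derived from the V values: the NSP head and the MLM head are separate layers reading different output vectors.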

For me, diagrams are hard to understand until I know the underlying calculations (or code). So I would encourage you to go through the code first, understand what it does, and only then look at the diagrams.
You can search Google for “bert output layer”, then choose “Images”, and you will find many examples of output layer diagrams, but I think the understanding comes from implementing the thing (an example), after which all the diagrams start to make sense.

P.S. I remembered an excellent blog post by Jay Alammar that uses a lot of illustrations. If you prefer learning from illustrations/diagrams, it is a great post.

Thanks for sharing the pointers and for your patience in explaining the concepts to me. I will go through the links.

No worries @Ritu_Pande. These models (BERT, T5) were hard for me to understand too, because their implementations are not very straightforward (special tokens, multiple tasks, etc.).

@arvyzukai Thanks for sharing the links. I now understand the BERT implementation properly. Could you please share a code example and a blog post for a basic T5 implementation, similar to what you shared for BERT?


I’m happy to help, @Ritu_Pande. I don’t have a specific example of a basic T5 implementation, but a simple Google search (“basic t5 implementation github”) should give you some results to check out (like this or this). You have to be cautious, though, because they might not be fully correct.
On the other hand, this T5 code should be correct, although it is more complicated to follow.


I had searched, but wanted to check whether you had a simple, already curated implementation available. Thanks for the help!