Context Mask during Training

Could someone please explain why, during the training phase, a mask for the context vector is needed, and how it’s used?
And why the mask is not required during inference/sampling time?

Thanks for the help!

During the training phase, an attention mechanism is often employed to allow the model to focus on different parts of the input sequence while generating the output sequence. The context vector is a representation of the input sequence’s information that is used to influence the generation of each token in the output sequence.

In the context of attention mechanisms, a mask for the context vector is used for a specific reason: to prevent the model from attending to parts of the input sequence that it shouldn’t have access to during training. This is particularly important when generating sequences autoregressively, one token at a time, where each token depends on the previous tokens. Without such masking, the model could inadvertently “cheat” by looking ahead at tokens it’s supposed to generate in the future, making the training unrealistic and leading to poor generalization.
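The autoregressive look-ahead masking described above can be sketched as follows (a minimal illustration, assuming a toy sequence length of 4 and uniform raw attention scores; not the course's code):

```python
import torch

# Minimal sketch of a causal (look-ahead) attention mask.
# Positions a token must not attend to are set to -inf before the softmax,
# so token i can only attend to tokens 0..i.
seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # stand-in for raw attention scores

# Boolean mask that is True strictly above the diagonal (the "future" positions).
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(causal_mask, float("-inf"))

weights = torch.softmax(masked_scores, dim=-1)
print(weights)  # row i has non-zero weight only on columns 0..i
```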

Thanks for the reply @elirod

I understand the general notion of masking the input for training causal models, but in the specific example in the course the context is a 5-dimensional one-hot encoded vector representing the sprite category; it's not a sequence of tokens where future tokens need to be hidden.

The mask is actually random, and effectively introduces some noise into the context vector.

Still wondering why this is needed?

this is probably generated by chatgpt

What is generated by chatgpt? I don’t get it.

He is referring to my answer to your question.

Hi @Yacine_Mazari,

The context mask is there so the model can also be sampled unconditionally (in addition to the conditional sampling the context already provides).

The context mask as implemented in the course:
context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9)
returns a tensor of 1s and 0s. With the probability set to 0.9, roughly 90% of the entries will be 1s.

The 1s in the tensor leave the original context, c, unchanged for those samples (conditioning), while the small proportion of 0s zero out the context for the samples at those indices (unconditioning in this case).

c = c * context_mask.unsqueeze(-1)
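Putting the two lines together, here is a small runnable sketch (the batch size, seed, and one-hot construction are my own illustration, assuming the 5-category sprite context from the course):

```python
import torch

# Sketch of the context-dropout trick: a batch of 5-dimensional one-hot
# context vectors, each kept with probability 0.9 and zeroed out otherwise.
torch.manual_seed(0)
batch_size, n_classes = 8, 5
c = torch.nn.functional.one_hot(
    torch.randint(0, n_classes, (batch_size,)), n_classes
).float()

# 1 with probability 0.9, 0 with probability 0.1, per sample.
context_mask = torch.bernoulli(torch.zeros(batch_size) + 0.9)

# Broadcasting the per-sample mask over the context dimension:
# masked samples now train with an all-zero ("null") context, so the same
# network learns both conditional and unconditional denoising.
c_masked = c * context_mask.unsqueeze(-1)
print(c_masked)
```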

This masking also enables classifier-free guidance, where you introduce a scale factor that weighs the conditional prediction against the unconditional one at sampling time. In my experience, this helps with the diversity of the generated samples.
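The guidance combination can be sketched like this (the tensors and the variable name `guide_w` are stand-ins of my own; the random "predictions" are placeholders, not real model outputs):

```python
import torch

# Sketch of the classifier-free guidance combination at sampling time.
# eps_cond / eps_uncond stand in for the network's noise predictions made
# with the real context and with an all-zero context, respectively.
torch.manual_seed(0)
eps_cond = torch.randn(1, 3, 16, 16)    # hypothetical conditional prediction
eps_uncond = torch.randn(1, 3, 16, 16)  # hypothetical unconditional prediction
guide_w = 2.0                           # guidance scale

# Extrapolate away from the unconditional prediction toward the conditional one;
# guide_w = 0 recovers plain conditional sampling.
eps = (1 + guide_w) * eps_cond - guide_w * eps_uncond
print(eps.shape)
```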


Hi @Lawrytime

Welcome to the community.

Thanks for your contribution.