How does the Trax word embedding layer work?

The Embedding layer code is very simple.

As you can see, the forward propagation is just one line of code:

jnp.take(self.weights, x, axis=0)

It simply "takes" the rows of the weight matrix (self.weights) indexed by x. So if you have an embedding matrix with a vocabulary of length 20 and an embedding dimension of 4 (shape (20, 4)):

and you pass in a batch of two sentences (for example, x of shape (2, 4)):

the Embedding layer will return a tensor of shape (2, 4, 4). It adds one dimension: the embedding-size dimension.

This is all it does: it takes input of shape (batch_size, seqlen) and outputs shape (batch_size, seqlen, emb_size).
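Here is a minimal sketch of that lookup. It uses NumPy's `np.take`, which has the same semantics as `jnp.take` for this call; the embedding matrix and token ids are made-up example values, not Trax internals:

```python
import numpy as np

vocab_size, emb_size = 20, 4

# Hypothetical embedding matrix: row i is the vector for token id i.
weights = np.arange(vocab_size * emb_size, dtype=np.float32).reshape(vocab_size, emb_size)

# A batch of two "sentences", each a sequence of 4 token ids: shape (2, 4).
x = np.array([[1, 5, 0, 19],
              [2, 2, 7, 3]])

# The Embedding forward pass: gather the x'th rows of `weights`.
out = np.take(weights, x, axis=0)

print(out.shape)  # (2, 4, 4): (batch_size, seqlen, emb_size)
```

Note that `out[0, 0]` is exactly `weights[1]`, the embedding vector for token id 1 — each id in x is replaced by its row of the weight matrix.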
