How does the Trax word embedding layer work?

The code for the Embedding layer is very simple (see the Embedding code in the Trax source).

As you can see, forward propagation is just one line of code:

jnp.take(self.weights, x, axis=0)
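For in-range token ids, jnp.take(weights, x, axis=0) selects whole rows of weights, so it gives the same result as plain integer indexing weights[x]. The small sketch below checks this on a toy 5 x 3 matrix. The sizes and variable names are made up for illustration. It falls back to NumPy if JAX is not installed, because np.take behaves the same way for valid indices.

```python
try:
    import jax.numpy as jnp  # the module Trax uses in the forward pass
except ImportError:
    import numpy as jnp  # fallback: np.take has the same semantics for in-range indices

import numpy as np

# Toy weight matrix: vocabulary of 5 tokens, embedding dimension 3.
weights = jnp.arange(15.0).reshape(5, 3)

# One sequence of token ids.
x = jnp.array([4, 0, 2])

# Row lookup, written exactly like the Embedding forward pass.
looked_up = jnp.take(weights, x, axis=0)

# Same result with integer-array indexing: row i of the output is weights[x[i]].
indexed = weights[x]

print(looked_up.shape)  # (3, 3)
print(np.asarray(looked_up)[0])  # row 4 of weights: [12. 13. 14.]
```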

It simply “takes” the rows of the weight matrix (self.weights) whose indices are given by x. So if you have an embedding matrix with a vocabulary size of 20 and an embedding dimension of 4 (shape (20, 4)):
[image: embedding weight matrix of shape (20, 4)]

and you pass in a batch of two sentences, each four tokens long (for example, x of shape (2, 4)):
[image: input batch x of shape (2, 4)]

the Embedding layer will return an array of shape (2, 4, 4). It adds one dimension, the embedding-size dimension:
[image: output of shape (2, 4, 4)]
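The same example can be reproduced in a few lines. This is a sketch: random values stand in for the learned (20, 4) weight matrix, and the token ids in x are arbitrary. Note that a token appearing twice (id 7 below) gets the identical embedding vector both times.

```python
try:
    import jax.numpy as jnp
except ImportError:
    import numpy as jnp  # np.take behaves the same for in-range indices

import numpy as np

vocab_size, emb_size = 20, 4

# Random values stand in for the learned embedding matrix, shape (20, 4).
rng = np.random.default_rng(0)
weights = jnp.asarray(rng.normal(size=(vocab_size, emb_size)))

# A batch of two sentences with four (arbitrary) token ids each: shape (2, 4).
x = jnp.array([[3, 7, 7, 0],
               [19, 2, 5, 1]])

out = jnp.take(weights, x, axis=0)
print(out.shape)  # (2, 4, 4)

# out[b, t] is the row of weights for token x[b, t].
```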

That is all it does: it takes an input of shape (batch_size, seqlen) and outputs an array of shape (batch_size, seqlen, emb_size).
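One way to see why this lookup works as a trainable layer: taking rows of the weight matrix gives the same result as one-hot encoding the token ids and multiplying by the weight matrix, only without building the large one-hot tensor. A sketch with made-up sizes (not Trax defaults):

```python
try:
    import jax.numpy as jnp
except ImportError:
    import numpy as jnp  # np.take behaves the same for in-range indices

import numpy as np

batch_size, seqlen, vocab_size, emb_size = 2, 3, 6, 4
rng = np.random.default_rng(1)
weights = jnp.asarray(rng.normal(size=(vocab_size, emb_size)))
x = jnp.asarray(rng.integers(0, vocab_size, size=(batch_size, seqlen)))

# Embedding lookup: (batch_size, seqlen) -> (batch_size, seqlen, emb_size)
lookup = jnp.take(weights, x, axis=0)

# One-hot route: (batch_size, seqlen, vocab_size) @ (vocab_size, emb_size)
one_hot = (x[..., None] == jnp.arange(vocab_size)).astype(weights.dtype)
via_matmul = one_hot @ weights

print(lookup.shape, via_matmul.shape)  # (2, 3, 4) (2, 3, 4)
```

The lookup avoids materializing a (batch_size, seqlen, vocab_size) tensor, which matters when the vocabulary is large.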
