Creating a GRU model using Trax

GRU = tl.Serial(
      tl.ShiftRight(mode=mode),  # Remember to pass the mode parameter when using the model for inference/test; the default is 'train'
      tl.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
      [tl.GRU(n_units=model_dimension) for _ in range(n_layers)],  # Change n_layers to stack more (or fewer) GRU layers
      tl.Dense(n_units=vocab_size),
      tl.LogSoftmax()
    )

Q-1: What is ShiftRight()? Why is it needed? (Can you explain it with a simple example that uses a NumPy array?) Later, in week 3, we build an LSTM model the same way but without the ShiftRight() layer. Why doesn’t the LSTM need a ShiftRight() layer?

Q-2: In line 4 (the tl.GRU() call), don’t we need to provide the number of hidden units? How will the model decide that?

Q-3: What does stacking GRUs mean (line 4)? I know a GRU takes a d_feature vector and outputs a d_feature vector, but why do we need to stack them? Isn’t a single GRU enough to predict the next word?

Q-4: In any ML problem we have to build X and Y, but in this text-generation algorithm we don’t seem to create Y at all. (In a sequence, the word after each word should be its Y, but we are not creating Y labels here. Why?)

Hi @Aayush_Jariwala

These are good questions and I will try my best :slight_smile: Someone might correct me where I’m wrong.

Q-1: ShiftRight shifts the tensor to the right along axis 1, padding with zeros on the left. As you asked, it is easier to explain with an example:

from trax import layers as tl
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6]])

tl.ShiftRight(n_positions=1, mode='train')(arr)

"""
Output:
DeviceArray([[0, 1, 2],
             [0, 4, 5]], dtype=int32)
"""

# the same thing with numpy: prepend a column of zeros, drop the last column
arr2 = np.array([[0],
                 [0]])

np.concatenate([arr2, arr[:, :-1]], axis=1)
"""
Output:
array([[0, 1, 2],
       [0, 4, 5]])
"""

This week it is used to predict the next character: in train/eval mode it shifts right, and in predict mode it does not. The goal is to train the model to predict the next character. In week 3 it is not used because we are predicting something different (NER tags): each input word is labeled with its own tag rather than with the next word, so inputs and labels are already aligned.
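A minimal NumPy sketch of the behaviour described above, assuming (as I understand the Trax source) that ShiftRight pads and drops in train/eval mode and returns its input unchanged in predict mode. shift_right is a hypothetical helper, not the Trax API:

```python
import numpy as np


def shift_right(x, n_positions=1, mode="train"):
    """Hypothetical NumPy mirror of tl.ShiftRight for a (batch, length) array."""
    if mode == "predict":
        # predict mode: returned unchanged (assumed to mirror Trax)
        return x
    # train/eval: pad n_positions zeros on the left of axis 1, then drop the
    # same number of columns on the right so the length is unchanged
    padded = np.pad(x, ((0, 0), (n_positions, 0)), constant_values=0)
    return padded[:, :-n_positions]


batch = np.array([[97, 98, 99, 1],
                  [100, 101, 102, 1]])
print(shift_right(batch).tolist())                  # [[0, 97, 98, 99], [0, 100, 101, 102]]
print(shift_right(batch, mode="predict").tolist())  # [[97, 98, 99, 1], [100, 101, 102, 1]]
```

Because the length is preserved, the shifted array and the original array line up position by position, which is what lets the original array serve as the labels.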

Q-2: I am not sure I understand, because we do provide the number of hidden units: n_units=model_dimension.

Q-3: Stacking layers is the “deep” part of deep learning. Very loosely speaking, one unit could be tracking gender, another singular/plural, yet another something else, and by stacking layers on top of each other you could imagine hierarchical dependencies that “make sense”, like plural + female giving “women”. In reality these are just weight matrices that best fit the loss (objective) function, and sometimes it happens that some layer actually starts to track gender or something similar. Check Andrej Karpathy’s famous blog post “The Unreasonable Effectiveness of Recurrent Neural Networks” (especially the section “Visualizing the predictions and the “neuron” firings in the RNN”).
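To make “stacking” concrete, here is a small NumPy sketch, not Trax’s implementation: gru_layer runs one common GRU formulation (no biases, random weights) over a (batch, length, d) array, and stacked_gru feeds each layer’s output sequence into the next one, which is what the list of n_layers GRUs inside tl.Serial does. All names and sizes here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_gru(d, rng):
    """Random weights for one GRU layer whose input and hidden sizes are both d."""
    return {name: rng.normal(0.0, 0.1, size=(2 * d, d)) for name in ("z", "r", "c")}


def gru_layer(params, xs):
    """Run one GRU layer over xs of shape (batch, length, d); returns the same shape."""
    batch, length, d = xs.shape
    h = np.zeros((batch, d))  # h_0: zeros with the same depth as the input
    outputs = []
    for t in range(length):
        x = xs[:, t, :]
        xh = np.concatenate([x, h], axis=-1)
        z = sigmoid(xh @ params["z"])  # update gate
        r = sigmoid(xh @ params["r"])  # reset gate
        c = np.tanh(np.concatenate([x, r * h], axis=-1) @ params["c"])  # candidate state
        h = z * h + (1.0 - z) * c      # mix old state and candidate
        outputs.append(h)
    return np.stack(outputs, axis=1)


def stacked_gru(layers, xs):
    """Each layer reads the full output sequence of the layer below it."""
    for params in layers:
        xs = gru_layer(params, xs)
    return xs


d_model, n_layers = 8, 2
layers = [init_gru(d_model, rng) for _ in range(n_layers)]
xs = rng.normal(size=(2, 5, d_model))  # (batch, length, d_model)
out = stacked_gru(layers, xs)
print(out.shape)  # (2, 5, 8)
```

Each layer keeps the (batch, length, d_model) shape, so stacking only adds depth: the second layer works on features computed by the first rather than on raw embeddings.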

Q-4: Again, here the Y’s are the next characters we are trying to predict; they come from the same sequence as the inputs.
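A tiny illustration of where the Y’s come from, using the line [97, 98, 99, 1] from this thread. make_pairs is a hypothetical helper that only lists the (input prefix, target) pairs; the model never builds these explicitly, it scores all positions of the shifted sequence at once.

```python
def make_pairs(line_ids):
    """Hypothetical helper: (input prefix, target) pairs for next-character prediction."""
    inputs = [0] + line_ids[:-1]  # what ShiftRight feeds into the model
    targets = line_ids            # the labels are the unshifted line itself
    return [(inputs[: t + 1], targets[t]) for t in range(len(line_ids))]


pairs = make_pairs([97, 98, 99, 1])
for prefix, y in pairs:
    print(prefix, "->", y)
# [0] -> 97
# [0, 97] -> 98
# [0, 97, 98] -> 99
# [0, 97, 98, 99] -> 1
```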


Referring to Q-1: That’s a really nice explanation, but the question still remains: why is ShiftRight() needed at all?
Suppose my batch data generator produced the following:
[97, 98, 99, 1]
[100, 101, 102, 1]
based on a dataset with max_length=4:
['abc',
 'def']

As I understand the math, it first takes 97 and converts it to an embedding of size d_feature.
Then it goes through the GRU layer and produces an output in the Dense layer, followed by LogSoftmax.

Based on the current dataset, LogSoftmax should produce the highest value for class 98. Backpropagation then takes care of adjusting the weights in the whole model.

Then it does the same thing, but now the input is 98 and the target is class 99.
After completing the first line, it repeats the same procedure for the line [100, 101, 102, 1].

In this whole process I don’t see where ShiftRight() is needed!

Why do I need to do something like this?
[97, 98, 99, 1] → [0, 97, 98, 99]
[100, 101, 102, 1] → [0, 100, 101, 102]

Referring to Q-2:
[Image: GRU architecture diagram]

Why are we not specifying the size of h<t_0>? For example, say 16 is the hidden-state size I want for my GRU.
With tl.GRU(n_units=model_dimension) it looks like we are only specifying the size of x<t_1>.

Referring to Q-4:
So can I assume that the model automatically uses Y=98 for X=97 from the line [97, 98, 99, 1], and continues until X=99 and Y=1?

Hi @Aayush_Jariwala

I apologize for the late reply.

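Regarding Q1: You need the ShiftRight layer so that your x_0 is 0. In other words, your input sequence starts with 0, so the model has to predict every real character, including the first one. Instead of starting from the initial hidden state with x_0 = 97, you start from the initial hidden state with x_0 = 0, and the model should try to predict 97. This is also where the labels come from: the data generator in this assignment yields the same array as both inputs and targets, and ShiftRight moves the inputs one step to the right, so at each position the model sees only the previous characters and is scored on the current one. Without the shift, the input at each position would equal its own target and the model could simply copy it.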

You could achieve the same thing by modifying the data_loader function, but this layer is probably faster, and there are scenarios where it is very handy.

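Regarding Q2: n_units is the size of the hidden state. We do not specify the size of h_0 separately, because it must have the same dimension as the embedding (in our case model_dimension, the same size as x_0, x_1, …). As far as I can tell, Trax’s tl.GRU creates h_0 as zeros with the input’s feature depth, which is why n_units has to match model_dimension.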
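A small NumPy sketch of the Q2 point, using a plain tanh RNN step instead of a full GRU to keep it short. make_zero_state and rnn_step are hypothetical helpers; make_zero_state assumes, as described above, that h_0 is built as zeros with the input’s feature depth.

```python
import numpy as np


def make_zero_state(xs):
    """Hypothetical helper: h_0 as zeros with the input's feature depth."""
    batch, _, depth = xs.shape
    return np.zeros((batch, depth))


def rnn_step(x, h, W, U):
    """Plain tanh RNN step; the new state has n_units = W.shape[1] columns."""
    return np.tanh(x @ W + h @ U)


rng = np.random.default_rng(0)
xs = rng.normal(size=(2, 3, 4))  # (batch, length, d_model=4)
h0 = make_zero_state(xs)         # (2, 4)

# n_units == d_model: the state keeps the shape (2, 4) at every step
W, U = rng.normal(size=(4, 4)), rng.normal(size=(4, 4))
h = h0
for t in range(xs.shape[1]):
    h = rnn_step(xs[:, t, :], h, W, U)
print(h.shape)  # (2, 4)

# n_units=16 with d_model=4: h0 of shape (2, 4) cannot be multiplied by U (16, 16)
W16, U16 = rng.normal(size=(4, 16)), rng.normal(size=(16, 16))
try:
    rnn_step(xs[:, 0, :], h0, W16, U16)
    mismatch = False
except ValueError:
    mismatch = True
print(mismatch)  # True
```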

Regarding Q4: I think you understand it correctly. The loss is averaged over the sequence of predictions: [0] → [97], [0, 97] → [98], [0, 97, 98] → [99], [0, 97, 98, 99] → [1]
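A hedged NumPy sketch of that averaged loss, assuming a 0/1 mask that excludes padding positions (masked_cross_entropy is a hypothetical helper, not Trax’s loss layer). It picks the log-probability of each target id and averages the negatives over the unmasked positions.

```python
import numpy as np


def masked_cross_entropy(log_probs, targets, mask):
    """Mean negative log-likelihood of the target ids over unmasked positions.

    log_probs: (batch, length, vocab); targets, mask: (batch, length).
    """
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -(picked * mask).sum() / mask.sum()


vocab_size = 128
targets = np.array([[97, 98, 99, 1, 0]])  # 'abc' + EOS (1) + one padding 0
mask = np.array([[1, 1, 1, 1, 0]])        # the padding position does not count

# an "untrained" model that is uniform over the vocabulary
uniform = np.full((1, 5, vocab_size), -np.log(vocab_size))
print(masked_cross_entropy(uniform, targets, mask))  # log(128), about 4.852

# a model that gives each correct character probability 0.5
better = np.full((1, 5, vocab_size), np.log(0.5 / (vocab_size - 1)))
np.put_along_axis(better, targets[..., None], np.log(0.5), axis=-1)
print(masked_cross_entropy(better, targets, mask))   # log(2), about 0.693
```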

Here is an example of predicting the next character when the input is ‘t’:

====================================================================================
# starting with letter 't', predict up to 32 characters
#print(predict(32, "t"))  

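# the input 't' is converted to its integer code (ord('t') == 116)
# inp:
# [116]

# the input is padded with zeros to length 33 (1 prefix character + 32 to predict) and turned into an array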
DeviceArray([116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=int32)

# added batch dimension
DeviceArray([[116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)

#input.shape
#(1, 33)

#inputs
DeviceArray([[116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)
====================================================================================
# ShiftRight layer
Serial[
  ShiftRight(1)
]

# outputs
DeviceArray([[  0, 116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)

# outputs.shape
(1, 33)

====================================================================================
# Embedding layer
Embedding_256_512

# layer weights shape
(256, 512)

# layer outputs
DeviceArray([[[ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.10795548, -0.03526021, -0.08854008, ...,  0.01529627,
                0.0868175 ,  0.09092615],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              ...,
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921]]], dtype=float32)

# note: every `0` gets the same embedding vector

# outputs.shape
(1, 33, 512)
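The identical rows above are just a table lookup: the embedding layer is a (vocab_size, d_model) weight matrix indexed by token id. A NumPy sketch with random stand-in weights (not the trained ones), with sizes taken from Embedding_256_512:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 256, 512
table = rng.normal(0.0, 0.05, size=(vocab_size, d_model))  # stand-in for learned weights

ids = np.array([[0, 116] + [0] * 31])  # (1, 33): the shifted input from the trace
embedded = table[ids]                  # one row per id -> (1, 33, 512)

print(embedded.shape)                                  # (1, 33, 512)
print(np.array_equal(embedded[0, 0], embedded[0, 2]))  # True: both positions hold id 0
print(np.array_equal(embedded[0, 0], embedded[0, 1]))  # False: id 0 vs id 116
```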

====================================================================================
# first GRU layer
GRU_512

# layer outputs
DeviceArray([[[ 0.20887429,  0.42045066,  0.24031094, ..., -0.25590184,
               -0.20555982, -0.26257256],
              [ 0.6903972 ,  0.563846  ,  0.6010454 , ..., -0.9051461 ,
               -0.506419  , -0.46426502],
              [ 0.15103428,  0.814925  ,  0.2067267 , ..., -0.29996556,
               -0.49159542, -0.5896287 ],
              ...,
              [ 0.73863196,  0.6200977 ,  0.63184285, ..., -0.9044281 ,
               -0.08520836, -0.6938022 ],
              [ 0.73863304,  0.62009877,  0.63184524, ..., -0.9044295 ,
               -0.0852062 , -0.69380426],
              [ 0.7386341 ,  0.62009937,  0.63184786, ..., -0.9044308 ,
               -0.0852043 , -0.6938063 ]]], dtype=float32)

# note: the `0` positions now have slightly different values, because the GRU's hidden state carries information from earlier steps

# outputs.shape
(1, 33, 512)

======================================================================================
# second GRU layer
GRU_512

# outputs
DeviceArray([[[-0.02699553,  0.68844396, -0.45988116, ..., -0.64575034,
                0.99919385, -0.9338472 ],
              [-0.13541652, -0.9974171 , -0.46418995, ..., -0.65407974,
                0.9975511 , -0.9740795 ],
              [-0.62881285, -0.4357384 , -0.4716416 , ..., -0.672858  ,
                0.9999947 , -0.9998259 ],
              ...,
              [-1.        ,  0.87456065, -0.5058002 , ..., -0.7197592 ,
                0.9775919 , -0.99944305],
              [-1.        ,  0.87489724, -0.5069449 , ..., -0.72155267,
                0.9775692 , -0.99944377],
              [-1.        ,  0.8752117 , -0.50808483, ..., -0.7233412 ,
                0.9775478 , -0.99944407]]], dtype=float32)

# outputs.shape
(1, 33, 512)

======================================================================================
# Linear layer
Dense_256

# outputs
DeviceArray([[[-4.624827  , -2.8908231 , -5.1104207 , ..., -4.9862876 ,
               -4.7495956 , -5.9244623 ],
              [-3.6607375 , -0.39768016, -4.6184626 , ..., -5.850813  ,
               -4.533198  , -5.199797  ],
              [-6.8010573 , -2.9996667 , -8.04726   , ..., -6.9786253 ,
               -6.5035443 , -8.284581  ],
              ...,
              [-8.779667  , -1.2761917 , -9.385116  , ..., -7.8552837 ,
               -8.279278  , -9.4966755 ],
              [-8.781777  , -1.2657752 , -9.386472  , ..., -7.8565826 ,
               -8.281515  , -9.496905  ],
              [-8.783843  , -1.2556825 , -9.387917  , ..., -7.8577847 ,
               -8.283754  , -9.497269  ]]], dtype=float32)

# outputs.shape
(1, 33, 256)

======================================================================================
# LogSoftmax layer
LogSoftmax

# outputs
DeviceArray([[[-13.580183, -11.84618 , -14.065777, ..., -13.941645,
               -13.704952, -14.879819],
              [-15.388636, -12.125579, -16.346361, ..., -17.578712,
               -16.261097, -16.927696],
              [-17.529985, -13.728596, -18.776188, ..., -17.707554,
               -17.232473, -19.01351 ],
              ...,
              [-19.89362 , -12.390145, -20.49907 , ..., -18.969238,
               -19.39323 , -20.61063 ],
              [-19.894333, -12.378332, -20.499027, ..., -18.96914 ,
               -19.394072, -20.609463],
              [-19.895145, -12.366985, -20.49922 , ..., -18.969088,
               -19.395058, -20.608572]]], dtype=float32)

# outputs.shape
(1, 33, 256)
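LogSoftmax only subtracts one constant (the log-sum-exp) from every entry of a position’s Dense output, so each position becomes a log-probability distribution. You can see this in the trace: the first Dense output row and the first LogSoftmax row differ by about 8.955 in every visible column. A NumPy sketch with random stand-in logits:

```python
import numpy as np


def log_softmax(logits):
    """Subtract each position's log-sum-exp (computed stably) from its logits."""
    m = logits.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    return logits - lse


rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 33, 256))  # stand-in for the Dense_256 output
log_probs = log_softmax(logits)

# every position is now a distribution over the 256 character ids
print(np.allclose(np.exp(log_probs).sum(axis=-1), 1.0))  # True
# within one position, every entry moved down by the same constant
shift = logits - log_probs
print(np.allclose(shift, shift[..., :1]))  # True
```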

======================================================================================
# sample the next character: add random (Gumbel) noise g to the log-probabilities at position len(inp) and take the argmax

# np.argmax(log_probs + g * temperature, axis=-1)

# next_char
DeviceArray(111, dtype=int32)

# add prediction to current input
# inp += [int(next_char)]
# inp
[116, 111]

# which translates to 'o'
# result.append(chr(int(next_char)))
# result
['t', 'o']
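A sketch of the sampling rule quoted in the trace, np.argmax(log_probs + g * temperature, axis=-1), with g drawn as standard Gumbel noise. The name gumbel_sample and the clipping bounds on u are my assumptions, not necessarily the assignment’s exact code.

```python
import numpy as np


def gumbel_sample(log_probs, temperature=1.0, rng=None):
    """Sample an id via argmax(log_probs + g * temperature), g ~ Gumbel(0, 1)."""
    rng = np.random.default_rng() if rng is None else rng
    # keep u away from 0 and 1 so both logs stay finite
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))  # standard Gumbel noise
    return np.argmax(log_probs + g * temperature, axis=-1)


log_probs = np.full(256, -20.0)
log_probs[111] = 0.0  # the model is very sure the next character is 'o'
print(gumbel_sample(log_probs, temperature=0.0))  # 111: no noise, plain argmax
print(chr(int(gumbel_sample(log_probs, rng=np.random.default_rng(0)))))  # o
```

For a temperature T > 0, argmax(log_probs + T·g) equals argmax(log_probs / T + g), so this is Gumbel-max sampling from softmax(log_probs / T); T = 0 reduces to greedy argmax.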

======================================================================================
======================================================================================
# the next input for the whole sequence becomes
DeviceArray([116, 111,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=int32)
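The whole loop above can be sketched in a few lines. This is a hypothetical reconstruction, not the assignment’s exact function: model_fn stands in for the trained GRU model (any function from a (1, max_len) int array to (1, max_len, vocab) log-probs), greedy argmax replaces Gumbel sampling for simplicity, and toy_model is a made-up stand-in that applies a right shift and then predicts “previous id + 1”.

```python
import numpy as np


def predict(model_fn, num_chars, prefix, eos_id=1):
    """Hypothetical generation loop mirroring the trace above (greedy decoding)."""
    inp = [ord(c) for c in prefix]
    result = list(prefix)
    max_len = len(prefix) + num_chars
    for _ in range(num_chars):
        cur = np.array(inp + [0] * (max_len - len(inp)))[None, :]  # pad, add batch dim
        log_probs = model_fn(cur)
        # after the shift, position len(inp) has seen exactly the tokens in inp
        next_char = int(np.argmax(log_probs[0, len(inp)]))
        if next_char == eos_id:
            break
        inp.append(next_char)
        result.append(chr(next_char))
    return "".join(result)


def toy_model(ids, vocab_size=256):
    """Stand-in model: shift right, then put all mass on 'previous id + 1'."""
    shifted = np.pad(ids, ((0, 0), (1, 0)))[:, :-1]
    log_probs = np.full(ids.shape + (vocab_size,), -30.0)
    nxt = (shifted + 1) % vocab_size
    np.put_along_axis(log_probs, nxt[..., None], 0.0, axis=-1)
    return log_probs


print(predict(toy_model, 5, "t"))  # tuvwxy
```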