Creating a GRU model using Trax

Hi @Aayush_Jariwala

I apologize for the late reply.

Regarding Q1: You need the ShiftRight layer so that x_0 is 0. In other words, every sequence starts with 0, and the model has to predict the first real character from that. So instead of starting from the initial hidden state with x_0 = 97, you start from the initial hidden state with x_0 = 0, and the model should try to predict 97.

You could probably achieve the same thing by modifying the data_loader function, but doing the shift in a layer is likely faster, and there are probably scenarios where having it as a layer is very handy.
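To make the shift concrete, here is a minimal pure-Python sketch of what shifting right by one position does to a token sequence. The helper `shift_right` is hypothetical and written only for illustration; it is not Trax's implementation, which works on batched arrays. I am also assuming that 1 is the end-of-line id, as the sequence in the Q4 answer suggests.

```python
def shift_right(tokens, n_positions=1, pad_id=0):
    """Prepend n_positions pad tokens and drop the same number from the end."""
    if n_positions <= 0:
        return list(tokens)
    padded = [pad_id] * n_positions + list(tokens)
    # Keep the original length, so the last token(s) fall off the end.
    return padded[:len(tokens)]


line = [97, 98, 99, 1]            # 'a', 'b', 'c', then the assumed end-of-line id
model_input = shift_right(line)   # what the recurrent layers would see
print(model_input)                # [0, 97, 98, 99]
```

At every position the model input is the previous character, while the target stays the unshifted line, so the model never sees the character it is asked to predict.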

Regarding Q2: We do not need to specify the size of h_0, because it must have the same dimension as the embedding (in our case model_dimension, the same size as x_0, x_1, …).
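If it helps, the reason the sizes have to match shows up when you write the recurrence out by hand. Below is a rough pure-Python sketch of a GRU-style step, not the Trax code: all names (`gru_step`, `run_gru`, `random_matrix`) are made up for illustration, biases are left out, the weights are random, and I am assuming h_0 is a zero vector whose width is read off the input. A toy embedding size of 4 stands in for 512.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(weights, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weights]


def random_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(n_cols)] for _ in range(n_rows)]


def gru_step(x, h, w_update, w_reset, w_cand):
    """One GRU-style update; x and h must have the same length d."""
    xh = x + h                                   # list concatenation [x ; h]
    u = [sigmoid(a) for a in matvec(w_update, xh)]
    r = [sigmoid(a) for a in matvec(w_reset, xh)]
    x_rh = x + [ri * hi for ri, hi in zip(r, h)]
    c = [math.tanh(a) for a in matvec(w_cand, x_rh)]
    # Element-wise mix of old state and candidate: only works if sizes agree.
    return [ui * hi + (1.0 - ui) * ci for ui, hi, ci in zip(u, h, c)]


def run_gru(embeddings, rng):
    d = len(embeddings[0])                       # embedding size fixes the state size
    w_update, w_reset, w_cand = (random_matrix(d, 2 * d, rng) for _ in range(3))
    h = [0.0] * d                                # h_0: zeros, never passed in by the caller
    outputs = []
    for x in embeddings:
        h = gru_step(x, h, w_update, w_reset, w_cand)
        outputs.append(h)
    return outputs


rng = random.Random(0)
toy_embeddings = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(3)]
outputs = run_gru(toy_embeddings, rng)
print(len(outputs), len(outputs[0]))             # 3 4
```

The only size the caller supplies is the embedding width; the state width follows from it, which is why the layer needs a single number of units and no separate h_0 argument.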

Regarding Q4: I think you understand it correctly. The loss is averaged over the sequence of predictions: [0] → [97], [0, 97] → [98], [0, 97, 98] → [99], [0, 97, 98, 99] → [1]
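As a small illustration of those pairs and of the averaging, here is a hedged pure-Python sketch. Both helper names are hypothetical, the "model" is just a uniform distribution over 256 characters, and no padding mask is applied, so treat it as a picture of the idea rather than the loss the assignment actually uses.

```python
import math


def prefix_target_pairs(line, pad_id=0):
    """Return (input prefix, target) pairs for one tokenized line."""
    shifted = [pad_id] + list(line[:-1])
    return [(shifted[: t + 1], line[t]) for t in range(len(line))]


def mean_negative_log_likelihood(log_probs, targets):
    """Average of -log p(target) over all positions (no padding mask here)."""
    total = sum(step[target] for step, target in zip(log_probs, targets))
    return -total / len(targets)


line = [97, 98, 99, 1]
for prefix, target in prefix_target_pairs(line):
    print(prefix, "->", target)

# A stand-in "model" that assigns every character the same probability.
vocab_size = 256
uniform = [[math.log(1.0 / vocab_size)] * vocab_size for _ in line]
loss = mean_negative_log_likelihood(uniform, line)
print(round(loss, 3))  # 5.545, i.e. log(256)
```

A trained model should end up well below log(256), since that value corresponds to guessing uniformly at every position.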

Here is an example of predicting the next character when the input is ‘t’:

====================================================================================
# starting with letter 't', predict up to 32 characters
#print(predict(32, "t"))  

# the input 't' is changed to integer
# inp:
# [116]

# input is changed to array
DeviceArray([116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=int32)

# added batch dimension
DeviceArray([[116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)

#input.shape
#(1, 33)

#inputs
DeviceArray([[116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)
====================================================================================
# ShiftRight layer
Serial[
  ShiftRight(1)
]

# outputs
DeviceArray([[  0, 116,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
                0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=int32)

# outputs.shape
(1, 33)

====================================================================================
# Embedding layer
Embedding_256_512

# layer weights shape
(256, 512)

# layer outputs
DeviceArray([[[ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.10795548, -0.03526021, -0.08854008, ...,  0.01529627,
                0.0868175 ,  0.09092615],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              ...,
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921],
              [ 0.04724933, -0.02226537, -0.07821855, ...,  0.08575993,
               -0.00808371,  0.02202921]]], dtype=float32)

# note all `0` have the same embedding

# outputs.shape
(1, 33, 512)

====================================================================================
# first GRU layer
GRU_512

# layer outputs
DeviceArray([[[ 0.20887429,  0.42045066,  0.24031094, ..., -0.25590184,
               -0.20555982, -0.26257256],
              [ 0.6903972 ,  0.563846  ,  0.6010454 , ..., -0.9051461 ,
               -0.506419  , -0.46426502],
              [ 0.15103428,  0.814925  ,  0.2067267 , ..., -0.29996556,
               -0.49159542, -0.5896287 ],
              ...,
              [ 0.73863196,  0.6200977 ,  0.63184285, ..., -0.9044281 ,
               -0.08520836, -0.6938022 ],
              [ 0.73863304,  0.62009877,  0.63184524, ..., -0.9044295 ,
               -0.0852062 , -0.69380426],
              [ 0.7386341 ,  0.62009937,  0.63184786, ..., -0.9044308 ,
               -0.0852043 , -0.6938063 ]]], dtype=float32)

# note all `0` have slightly different values

# outputs.shape
(1, 32, 512)

======================================================================================
# second GRU layer
GRU_512

# outputs
DeviceArray([[[-0.02699553,  0.68844396, -0.45988116, ..., -0.64575034,
                0.99919385, -0.9338472 ],
              [-0.13541652, -0.9974171 , -0.46418995, ..., -0.65407974,
                0.9975511 , -0.9740795 ],
              [-0.62881285, -0.4357384 , -0.4716416 , ..., -0.672858  ,
                0.9999947 , -0.9998259 ],
              ...,
              [-1.        ,  0.87456065, -0.5058002 , ..., -0.7197592 ,
                0.9775919 , -0.99944305],
              [-1.        ,  0.87489724, -0.5069449 , ..., -0.72155267,
                0.9775692 , -0.99944377],
              [-1.        ,  0.8752117 , -0.50808483, ..., -0.7233412 ,
                0.9775478 , -0.99944407]]], dtype=float32)

# outputs.shape
(1, 32, 512)

======================================================================================
# Linear layer
Dense_256

# outputs
DeviceArray([[[-4.624827  , -2.8908231 , -5.1104207 , ..., -4.9862876 ,
               -4.7495956 , -5.9244623 ],
              [-3.6607375 , -0.39768016, -4.6184626 , ..., -5.850813  ,
               -4.533198  , -5.199797  ],
              [-6.8010573 , -2.9996667 , -8.04726   , ..., -6.9786253 ,
               -6.5035443 , -8.284581  ],
              ...,
              [-8.779667  , -1.2761917 , -9.385116  , ..., -7.8552837 ,
               -8.279278  , -9.4966755 ],
              [-8.781777  , -1.2657752 , -9.386472  , ..., -7.8565826 ,
               -8.281515  , -9.496905  ],
              [-8.783843  , -1.2556825 , -9.387917  , ..., -7.8577847 ,
               -8.283754  , -9.497269  ]]], dtype=float32)

# outputs.shape
(1, 32, 256)

======================================================================================
# LogSoftmax layer
LogSoftmax

# outputs
DeviceArray([[[-13.580183, -11.84618 , -14.065777, ..., -13.941645,
               -13.704952, -14.879819],
              [-15.388636, -12.125579, -16.346361, ..., -17.578712,
               -16.261097, -16.927696],
              [-17.529985, -13.728596, -18.776188, ..., -17.707554,
               -17.232473, -19.01351 ],
              ...,
              [-19.89362 , -12.390145, -20.49907 , ..., -18.969238,
               -19.39323 , -20.61063 ],
              [-19.894333, -12.378332, -20.499027, ..., -18.96914 ,
               -19.394072, -20.609463],
              [-19.895145, -12.366985, -20.49922 , ..., -18.969088,
               -19.395058, -20.608572]]], dtype=float32)

# outputs.shape
(1, 32, 256)

======================================================================================
# make a prediction according to previous outputs and added noise

# np.argmax(log_probs + g * temperature, axis=-1)

# next_char
DeviceArray(111, dtype=int32)

# add prediction to current input
# inp += [int(next_char)]
# inp
[116, 111]

# which translates to 'o'
# result.append(chr(int(next_char)))
# result
['t', 'o']

======================================================================================
======================================================================================
# the next input for the whole sequence becomes
DeviceArray([116, 111,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
               0,   0,   0,   0,   0,   0,   0,   0], dtype=int32)
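The sampling step in the trace above (`np.argmax(log_probs + g * temperature, axis=-1)`) can be sketched without any array library. This is only an illustration under a few assumptions: `g` is taken to be Gumbel noise computed as -log(-log(u)) from uniform samples, the function below is my own pure-Python stand-in rather than the course's code, and the log-probabilities are toy values that are not normalized.

```python
import math
import random


def gumbel_sample(log_probs, temperature=1.0, rng=random):
    """Return the index maximizing log_prob + temperature * Gumbel noise."""
    best_index, best_score = 0, -math.inf
    for index, log_prob in enumerate(log_probs):
        u = rng.uniform(1e-6, 1.0 - 1e-6)          # keep u away from 0 and 1
        gumbel_noise = -math.log(-math.log(u))
        score = log_prob + gumbel_noise * temperature
        if score > best_score:
            best_index, best_score = index, score
    return best_index


# Toy scores (not a real distribution) in which character 111 ('o') dominates.
toy_log_probs = [-20.0] * 256
toy_log_probs[111] = -0.5

inp = [116]                                        # 't'
result = [chr(inp[0])]
next_char = gumbel_sample(toy_log_probs, temperature=0.0)
inp += [int(next_char)]
result.append(chr(int(next_char)))
print(inp, result)                                 # [116, 111] ['t', 'o']
```

With temperature 0.0 the noise is switched off and this reduces to a plain argmax; larger temperatures let lower-scoring characters win more often, which makes the generated text more varied.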