Hi @Aayush_Jariwala
I apologize for the late reply.
Regarding Q1: You need the ShiftRight layer so that your x_0 is 0. In other words, your sequence starts from 0, so your model can predict the next character. Instead of starting from the initial hidden state with x_0 being 97, you start from the initial hidden state with x_0 being 0, and your model should try to predict 97.
You could probably achieve the same thing by modifying the data_loader function, but this layer is likely faster, and there are probably scenarios where it is very handy.
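If it helps, here is roughly what ShiftRight does, sketched in numpy (the function name `shift_right` is mine; the real layer is trax's tl.ShiftRight):

```python
import numpy as np

def shift_right(x, n_positions=1, pad=0):
    # Prepend a padding token and drop the last position(s), so that
    # position t of the shifted sequence holds the original token t-1.
    batch_pad = np.full((x.shape[0], n_positions), pad, dtype=x.dtype)
    return np.concatenate([batch_pad, x[:, :-n_positions]], axis=1)

seq = np.array([[97, 98, 99, 1]])   # 'a', 'b', 'c', EOS
shifted = shift_right(seq)          # [[ 0 97 98 99]]
```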
Regarding Q2: We do not need to specify the size of h_0, because it must have the same dimension as the embedding dimension (in our case model_dimension, the same size as x_0, x_1, …).
Regarding Q4: I think you understand it correctly. The loss is averaged over the sequence of predictions: [0] → [97], [0, 97] → [98], [0, 97, 98] → [99], [0, 97, 98, 99] → [1]
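A tiny numpy sketch of that averaging (random logits and a hypothetical 4-step target sequence, just to show the mechanics):

```python
import numpy as np

# Cross-entropy averaged over the sequence of predictions.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 256))       # 4 positions, 256-char vocab
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
targets = np.array([97, 98, 99, 1])      # the next character at each step
nll = -log_probs[np.arange(4), targets]  # per-position negative log-prob
loss = nll.mean()                        # one scalar, averaged over positions
```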
Here is an example of predicting the next character when the input is 't':
====================================================================================
# starting with letter 't', predict up to 32 characters
#print(predict(32, "t"))
# the input 't' is converted to its integer code (ord('t') == 116)
# inp:
# [116]
# the input is zero-padded into a fixed-length array
DeviceArray([116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
# added batch dimension
DeviceArray([[116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
#input.shape
#(1, 33)
#inputs
DeviceArray([[116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
====================================================================================
# ShiftRight layer
Serial[
ShiftRight(1)
]
# outputs
DeviceArray([[ 0, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
# outputs.shape
(1, 33)
====================================================================================
# Embedding layer
Embedding_256_512
# layer weights shape
(256, 512)
# layer outputs
DeviceArray([[[ 0.04724933, -0.02226537, -0.07821855, ..., 0.08575993,
-0.00808371, 0.02202921],
[ 0.10795548, -0.03526021, -0.08854008, ..., 0.01529627,
0.0868175 , 0.09092615],
[ 0.04724933, -0.02226537, -0.07821855, ..., 0.08575993,
-0.00808371, 0.02202921],
...,
[ 0.04724933, -0.02226537, -0.07821855, ..., 0.08575993,
-0.00808371, 0.02202921],
[ 0.04724933, -0.02226537, -0.07821855, ..., 0.08575993,
-0.00808371, 0.02202921],
[ 0.04724933, -0.02226537, -0.07821855, ..., 0.08575993,
-0.00808371, 0.02202921]]], dtype=float32)
# note all `0` have the same embedding
# outputs.shape
(1, 33, 512)
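The embedding step is just indexing into the weight matrix, which is why every `0` gets the identical row. A sketch with random weights and a shrunken hypothetical d_model:

```python
import numpy as np

# Embedding = table lookup: row i of `weights` is the vector for token i.
rng = np.random.default_rng(0)
weights = rng.normal(size=(256, 8))   # (vocab_size, d_model), d_model shrunk
tokens = np.array([[116, 0, 0, 0]])   # shifted input, batch of 1
embedded = weights[tokens]            # shape (1, 4, 8)
# All `0` positions pick the same row, hence identical embeddings.
```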
====================================================================================
# first GRU layer
GRU_512
# layer outputs
DeviceArray([[[ 0.20887429, 0.42045066, 0.24031094, ..., -0.25590184,
-0.20555982, -0.26257256],
[ 0.6903972 , 0.563846 , 0.6010454 , ..., -0.9051461 ,
-0.506419 , -0.46426502],
[ 0.15103428, 0.814925 , 0.2067267 , ..., -0.29996556,
-0.49159542, -0.5896287 ],
...,
[ 0.73863196, 0.6200977 , 0.63184285, ..., -0.9044281 ,
-0.08520836, -0.6938022 ],
[ 0.73863304, 0.62009877, 0.63184524, ..., -0.9044295 ,
-0.0852062 , -0.69380426],
[ 0.7386341 , 0.62009937, 0.63184786, ..., -0.9044308 ,
-0.0852043 , -0.6938063 ]]], dtype=float32)
# note the `0` positions now have slightly different values (the hidden state evolves)
# outputs.shape
(1, 32, 512)
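The reason identical `0` inputs produce slightly different outputs is that the hidden state keeps evolving from step to step. A minimal GRU-cell sketch (tiny hypothetical dimensions, not trax's actual implementation):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, Wz, Wr, Wh):
    # Standard GRU cell: gates from [x; h], candidate from [x; r*h].
    z = sigmoid(Wz @ np.concatenate([x, h]))           # update gate
    r = sigmoid(Wr @ np.concatenate([x, h]))           # reset gate
    h_cand = np.tanh(Wh @ np.concatenate([x, r * h]))  # candidate state
    return (1 - z) * h + z * h_cand

rng = np.random.default_rng(1)
d = 4
Wz, Wr, Wh = (rng.normal(size=(d, 2 * d)) for _ in range(3))
h0 = np.zeros(d)
x = np.ones(d)                     # the same input vector at every step
h1 = gru_step(x, h0, Wz, Wr, Wh)
h2 = gru_step(x, h1, Wz, Wr, Wh)   # differs from h1: the state carries over
```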
======================================================================================
# second GRU layer
GRU_512
# outputs
DeviceArray([[[-0.02699553, 0.68844396, -0.45988116, ..., -0.64575034,
0.99919385, -0.9338472 ],
[-0.13541652, -0.9974171 , -0.46418995, ..., -0.65407974,
0.9975511 , -0.9740795 ],
[-0.62881285, -0.4357384 , -0.4716416 , ..., -0.672858 ,
0.9999947 , -0.9998259 ],
...,
[-1. , 0.87456065, -0.5058002 , ..., -0.7197592 ,
0.9775919 , -0.99944305],
[-1. , 0.87489724, -0.5069449 , ..., -0.72155267,
0.9775692 , -0.99944377],
[-1. , 0.8752117 , -0.50808483, ..., -0.7233412 ,
0.9775478 , -0.99944407]]], dtype=float32)
# outputs.shape
(1, 32, 512)
======================================================================================
# Linear layer
Dense_256
# outputs
DeviceArray([[[-4.624827 , -2.8908231 , -5.1104207 , ..., -4.9862876 ,
-4.7495956 , -5.9244623 ],
[-3.6607375 , -0.39768016, -4.6184626 , ..., -5.850813 ,
-4.533198 , -5.199797 ],
[-6.8010573 , -2.9996667 , -8.04726 , ..., -6.9786253 ,
-6.5035443 , -8.284581 ],
...,
[-8.779667 , -1.2761917 , -9.385116 , ..., -7.8552837 ,
-8.279278 , -9.4966755 ],
[-8.781777 , -1.2657752 , -9.386472 , ..., -7.8565826 ,
-8.281515 , -9.496905 ],
[-8.783843 , -1.2556825 , -9.387917 , ..., -7.8577847 ,
-8.283754 , -9.497269 ]]], dtype=float32)
# outputs.shape
(1, 32, 256)
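The Dense layer is a single per-position projection from the 512-dim GRU output onto the 256-character vocabulary (random weights below, just to show the shapes):

```python
import numpy as np

rng = np.random.default_rng(2)
h = rng.normal(size=(1, 32, 512))   # second GRU's outputs
W = rng.normal(size=(512, 256))     # Dense_256 weights
b = np.zeros(256)                   # Dense_256 bias
logits = h @ W + b                  # one score per character, per position
```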
======================================================================================
# LogSoftmax layer
LogSoftmax
# outputs
DeviceArray([[[-13.580183, -11.84618 , -14.065777, ..., -13.941645,
-13.704952, -14.879819],
[-15.388636, -12.125579, -16.346361, ..., -17.578712,
-16.261097, -16.927696],
[-17.529985, -13.728596, -18.776188, ..., -17.707554,
-17.232473, -19.01351 ],
...,
[-19.89362 , -12.390145, -20.49907 , ..., -18.969238,
-19.39323 , -20.61063 ],
[-19.894333, -12.378332, -20.499027, ..., -18.96914 ,
-19.394072, -20.609463],
[-19.895145, -12.366985, -20.49922 , ..., -18.969088,
-19.395058, -20.608572]]], dtype=float32)
# outputs.shape
(1, 32, 256)
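LogSoftmax just normalizes each 256-long score vector into log-probabilities; a numerically stable sketch:

```python
import numpy as np

def log_softmax(x):
    # Subtract the per-row max first so np.exp never overflows.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

logits = np.array([[-4.62, -2.89, -5.11]])
log_probs = log_softmax(logits)
# exp(log_probs) sums to 1 along the last axis
```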
======================================================================================
# sample the next character from the log-probabilities plus scaled random noise
# np.argmax(log_probs + g * temperature, axis=-1)
# next_char
DeviceArray(111, dtype=int32)
# add prediction to current input
# inp += [int(next_char)]
# inp
[116, 111]
# which translates to 'o'
# result.append(chr(int(next_char)))
# result
['t', 'o']
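The `log_probs + g * temperature` step is Gumbel sampling: g is Gumbel(0, 1) noise, and at temperature 0 the argmax reduces to plain greedy decoding. A sketch with a toy 3-character distribution:

```python
import numpy as np

rng = np.random.default_rng(3)
log_probs = np.log(np.array([0.1, 0.7, 0.2]))   # toy distribution
u = rng.uniform(size=log_probs.shape)
g = -np.log(-np.log(u))                          # Gumbel(0, 1) noise
temperature = 0.0                                # 0 => greedy argmax
next_char = int(np.argmax(log_probs + g * temperature))
# With temperature 0 this picks the most probable index, 1
```

Raising the temperature mixes in more noise, so less likely characters get sampled more often.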
======================================================================================
# the next input for the whole sequence becomes
DeviceArray([116, 111, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)