C4_W4_Ungraded_Lab_1_Reformer_LSH weight W_o

Hi there,

I am working on this lab and noticed that in the OurSelfAttention class there is a weight w_o that multiplies the attention result o at the end.

# Notice, wo weight matrix applied to output of attend in forward_unbatched
out = np.matmul(o, w_o)

Figure 3 also shows this W_o matrix being multiplied with the attention output before concatenation.
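In standard multi-head attention, W_o is the output projection: a learned matrix that maps the concatenated head outputs back to width d_model. Multiplying each head's output by its own row-block of W_o and then summing gives the same result as concatenating the heads first and multiplying by the full W_o. That equivalence is presumably why w_o can be applied per head, before any concatenation. A minimal NumPy sketch, assuming hypothetical sizes and random weights (not the lab's actual values):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration
n_heads, seq_len, d_head, d_model = 2, 4, 3, 6

# One attention output o per head, each of shape (seq_len, d_head)
heads = [rng.standard_normal((seq_len, d_head)) for _ in range(n_heads)]

# Full output projection W_o, shape (n_heads * d_head, d_model)
w_o = rng.standard_normal((n_heads * d_head, d_model))

# Option A: concatenate the heads, then project with the full W_o
concat_then_project = np.matmul(np.concatenate(heads, axis=-1), w_o)

# Option B: project each head with its own row-block of W_o, then sum
per_head = [
    np.matmul(o, w_o[h * d_head:(h + 1) * d_head])
    for h, o in enumerate(heads)
]
project_then_sum = sum(per_head)

# Both are (seq_len, d_model) and agree up to floating-point rounding
print(concat_then_project.shape)  # (4, 6)
print(np.allclose(concat_then_project, project_then_sum))  # True
```

If the w_o seen inside forward_unbatched holds only one head's slice of the projection, then the line out = np.matmul(o, w_o) corresponds to Option B, and summing over heads recovers the usual concat-then-project output.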

I wonder what this weight matrix is and where it comes from. Thanks!

Hi @Longyu_Zhao

I’m not sure I fully understand your question. Could you be more specific about your doubts regarding the w_o weight matrix? This weight matrix is no different from the one in regular transformers (the ones we covered in previous weeks), and it is initialized the same way. I’m pretty sure it uses GlorotUniformInitializer for W and RandomNormalInitializer for b, as in the default weight initialization of the Dense layer.

https://trax-ml.readthedocs.io/en/latest/trax.layers.html#module-trax.layers.initializers

Dense layer’s __init__ function:


  def __init__(self,
               n_units,
               kernel_initializer=init.GlorotUniformInitializer(),
               bias_initializer=init.RandomNormalInitializer(1e-6),
               use_bias=True,
               use_bfloat16=False):
    """Returns a dense (fully connected) layer of width `n_units`.
    A dense layer maps collections of `R^m` vectors to `R^n`, where `n`
    (`= n_units`) is fixed at layer creation time, and `m` is set at layer
    initialization time.
    Args:
      n_units: Number of nodes in the layer, also known as the width of the
          layer.
      kernel_initializer: Function that creates a matrix of (random) initial
          connection weights `W` for the layer.
      bias_initializer: Function that creates a vector of (random) initial
          bias weights `b` for the layer.
      use_bias: If `True`, compute an affine map `y = Wx + b`; else compute
          a linear map `y = Wx`.
      use_bfloat16: If `True`, use bfloat16 weights instead of the default
        float32; this can save memory but may (rarely) lead to numerical issues.
    """
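As a rough illustration of what those defaults mean for a matrix like w_o, the NumPy sketch below draws a kernel from a Glorot (Xavier) uniform distribution on [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)), draws a bias from a normal distribution with standard deviation 1e-6, and applies the affine map. This is a hand-written approximation of the defaults listed above, not Trax's own code; the helpers glorot_uniform and dense_init and the sizes used are hypothetical.

```python
import numpy as np

def glorot_uniform(rng, fan_in, fan_out):
    """Sample a (fan_in, fan_out) kernel uniformly from [-limit, limit)."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

def dense_init(rng, n_in, n_units, bias_stddev=1e-6):
    """Hypothetical helper mirroring Dense's default initializers."""
    w = glorot_uniform(rng, n_in, n_units)
    b = rng.normal(0.0, bias_stddev, size=(n_units,))
    return w, b

rng = np.random.default_rng(0)
d_in, d_model = 6, 6  # hypothetical sizes
w_o, b_o = dense_init(rng, d_in, d_model)

x = rng.standard_normal((4, d_in))  # 4 row vectors of width d_in
y = np.matmul(x, w_o) + b_o         # row-vector form of y = Wx + b

limit = np.sqrt(6.0 / (d_in + d_model))
print(y.shape)                       # (4, 6)
print(np.all(np.abs(w_o) <= limit))  # True
```

With row vectors the map is written x @ W + b rather than Wx + b; the two are the same affine map up to a transpose of W.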