Week 1 - Jazz, music_inference_model, one hot encoding

When I run the following code (steps 2.D to 2.E) in the music_inference_model function:

```python
# Step 2.D:
# Select the next value according to "out",
# Set "x" to be the one-hot representation of the selected value
# See instructions above.
print("A")
print(x)
x = tf.math.argmax(input=x, axis=-1)
print("B")
print(x)
x = tf.one_hot(indices=x, depth=90)
# x = tf.keras.layers.Flatten()(tf.one_hot(indices=x, depth=90))
print("C")
print(x)
# Step 2.E:
# Use RepeatVector(1) to convert x into a tensor with shape=(None, 1, 90)
x = RepeatVector(n=1)(x)
print("D")
print(x)
```

I get the following output:

```text
A
Tensor("input_29:0", shape=(None, 1, 90), dtype=float32)
B
Tensor("ArgMax_465:0", shape=(None, 1), dtype=int64)
C
Tensor("OneHot_462:0", shape=(None, 1, 90), dtype=float32)
```

As you can see, the output of tf.one_hot has shape (None, 1, 90), whereas the instructions say it should be (None, 90), which is why RepeatVector is needed in the first place. Here it is not (None, 90), so RepeatVector fails with the following error:

```text
ValueError: Input 0 of layer repeat_vector_309 is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: [None, 1, 90]
```
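The shapes in the printout follow from how argmax and one-hot encoding treat the last axis: argmax removes only the last axis, and one-hot encoding adds a new one back, so a 3-D input stays 3-D. The sketch below is a NumPy analogue, not TensorFlow code: it uses a batch of 2 in place of `None`, and `one_hot` and `repeat_vector` are hypothetical helpers that mimic `tf.one_hot` and `RepeatVector` (including its 2-D input requirement).

```python
import numpy as np

N_VALUES = 90  # depth used in the assignment


def one_hot(indices, depth):
    # Hypothetical stand-in for tf.one_hot: appends a new last axis of size `depth`.
    return np.eye(depth, dtype=np.float32)[indices]


def repeat_vector(x, n):
    # Hypothetical stand-in for RepeatVector: accepts only 2-D (batch, features) input.
    if x.ndim != 2:
        raise ValueError(f"expected ndim=2, found ndim={x.ndim}")
    return np.repeat(x[:, np.newaxis, :], n, axis=1)


batch = 2
x = np.random.rand(batch, 1, N_VALUES).astype(np.float32)  # same shape as printout "A": (None, 1, 90)

idx = np.argmax(x, axis=-1)    # (2, 1), like printout "B"
x_oh = one_hot(idx, N_VALUES)  # (2, 1, 90), like printout "C": still 3-D
print(idx.shape, x_oh.shape)

try:
    repeat_vector(x_oh, 1)
except ValueError as err:
    print("RepeatVector-style check fails:", err)
```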

To work around the RepeatVector error, I flattened the output of tf.one_hot:

```python
x = tf.keras.layers.Flatten()(tf.one_hot(indices=x, depth=90))
```
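Flattening fixes the shape because the middle axis of the one-hot tensor has size 1: dropping it gives exactly the (batch, 90) tensor you would get by one-hot encoding (batch,) indices. A NumPy sketch of that equivalence, where `reshape` stands in for `tf.keras.layers.Flatten` and `one_hot` is a hypothetical helper mimicking `tf.one_hot`:

```python
import numpy as np

N_VALUES = 90


def one_hot(indices, depth):
    # Hypothetical stand-in for tf.one_hot.
    return np.eye(depth, dtype=np.float32)[indices]


batch = 2
x = np.random.rand(batch, 1, N_VALUES).astype(np.float32)  # shape of x: (None, 1, 90)

idx = np.argmax(x, axis=-1)                            # (2, 1)
flattened = one_hot(idx, N_VALUES).reshape(batch, -1)  # Flatten analogue: (2, 1, 90) -> (2, 90)
direct = one_hot(idx[:, 0], N_VALUES)                  # one-hot of (2,) indices: (2, 90)

print(flattened.shape, np.array_equal(flattened, direct))
```

Note that the selected indices are still computed from `x`, not from `out`, so the workaround only changes the shape, not which value gets selected.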

Flattening makes the code run, but the test comparator(inference_summary, music_inference_model_out) then fails:

```text
AttributeError: The layer "lstm" has multiple inbound nodes, with different input shapes. Hence the notion of "input shape" is ill-defined for the layer. Use get_input_shape_at(node_index) instead.
```

Could you please help me both make the code work and make sure the comparator returns no error, so that I can pass the assignment? Thanks!

Update: solved. I was passing x instead of out to tf.math.argmax.
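With `out` passed to argmax, the shapes line up with the instructions without any Flatten, assuming `out` is the Dense softmax output with shape (None, 90) as the instructions describe. The NumPy sketch below traces that flow with a batch of 2; it is an analogue rather than the assignment code, and `one_hot` and `repeat_vector` are hypothetical stand-ins for `tf.one_hot` and `RepeatVector(1)`.

```python
import numpy as np

N_VALUES = 90


def one_hot(indices, depth):
    # Hypothetical stand-in for tf.one_hot.
    return np.eye(depth, dtype=np.float32)[indices]


def repeat_vector(x, n):
    # Hypothetical stand-in for RepeatVector: (batch, features) -> (batch, n, features).
    if x.ndim != 2:
        raise ValueError(f"expected ndim=2, found ndim={x.ndim}")
    return np.repeat(x[:, np.newaxis, :], n, axis=1)


batch = 2
logits = np.random.rand(batch, N_VALUES)
out = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax-like output: (2, 90)

idx = np.argmax(out, axis=-1)  # (2,): one index per example
x = one_hot(idx, N_VALUES)     # (2, 90): the 2-D shape the instructions expect
x = repeat_vector(x, 1)        # (2, 1, 90): ready to be fed back as the next input
print(idx.shape, x.shape)
```

In the TensorFlow code above, this corresponds to changing step 2.D to `x = tf.math.argmax(input=out, axis=-1)`; tf.one_hot then returns (None, 90), so the Flatten workaround is no longer needed.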