Or maybe it’s better to simplify the point by going back to first principles here. If you get a shape mismatch, then the first question is always “Ok, what are the shapes?”
```python
print(flattened_inputs.shape)
print(flattened_outputs.shape)
```
What does that show? In the context of this problem, would you expect them to be the same?
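If you want this check to run every time, a small helper can print both shapes and fail loudly when they differ. This is a minimal sketch, not code from your project. It assumes NumPy-style arrays, and `check_same_shape` is a hypothetical helper name. The batch of 4 flattened 28x28 inputs and the width-10 outputs are made-up stand-ins that show what a mismatch looks like. Whether your two tensors *should* match depends on what your model is supposed to produce.

```python
import numpy as np


def check_same_shape(a, b, name_a="a", name_b="b"):
    """Print both shapes and raise ValueError if they differ (hypothetical helper)."""
    # np.shape accepts NumPy arrays, nested lists, and other array-likes.
    shape_a, shape_b = np.shape(a), np.shape(b)
    # Answer the first question: what are the shapes?
    print(f"{name_a}: {shape_a}")
    print(f"{name_b}: {shape_b}")
    if shape_a != shape_b:
        raise ValueError(
            f"shape mismatch: {name_a} {shape_a} vs {name_b} {shape_b}"
        )
    return shape_a


# Hypothetical stand-ins: 4 flattened 28x28 inputs, and 4 outputs of width 10.
flattened_inputs = np.zeros((4, 28 * 28))
flattened_outputs = np.zeros((4, 10))

try:
    check_same_shape(
        flattened_inputs, flattened_outputs, "flattened_inputs", "flattened_outputs"
    )
except ValueError as err:
    print(err)
```

The point is simply to make the shapes visible before whatever operation is complaining about them. Once you can see both, you can decide whether the data, the model, or your expectation is what needs to change.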