Clarification about tf.transpose perm

I cannot understand how the perm parameter of tf.transpose works.


As I understand it, applying perm = [0, 3, 1, 2] to a tensor of shape m * n_H * n_W * n_C should give:

  • m goes to position 0
  • n_H goes to position 3
  • n_W goes to position 1
  • n_C goes to position 2

Thus we should get the shape m * n_W * n_C * n_H.

However, TensorFlow gives a different answer.

perm specifies which old axis ends up in each new position: axis i of the output is axis perm[i] of the input.

So in the case of [0, 3, 1, 2] (counting axes from 0):
new axis 0 ← old 0
new axis 1 ← old 3
new axis 2 ← old 1
new axis 3 ← old 2
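
To make that mapping concrete, here is a minimal sketch in plain Python. It does not call TensorFlow; it only mirrors the rule "new axis i ← old axis perm[i]" on a shape tuple. The helper name transposed_shape and the sizes (m=2, n_H=4, n_W=5, n_C=3) are made up for illustration.

```python
def transposed_shape(shape, perm):
    """Shape after a transpose where output axis i is input axis perm[i]."""
    return tuple(shape[old_axis] for old_axis in perm)


# Hypothetical sizes, chosen only for illustration: m=2, n_H=4, n_W=5, n_C=3
shape = (2, 4, 5, 3)
perm = [0, 3, 1, 2]

new_shape = transposed_shape(shape, perm)
print(new_shape)  # (2, 3, 4, 5), i.e. (m, n_C, n_H, n_W)
```

If TensorFlow is installed, tf.transpose(x, perm=perm).shape for a tensor x of that shape should agree with this tuple.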

In particular, you start with (m, nh, nw, nc).
Axis 0 doesn't change, new axis 1 is the old axis 3, which is nc, and so on, so the result has shape (m, nc, nh, nw).
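
The reading in the question ("old axis i moves to position perm[i]") describes the inverse permutation. If that is the layout you want, one option is to invert perm first and pass the result instead. The sketch below is plain Python; invert_perm is a hypothetical helper, not a TensorFlow function.

```python
def invert_perm(perm):
    """Return inv such that inv[perm[i]] == i for every i."""
    inv = [0] * len(perm)
    for new_axis, old_axis in enumerate(perm):
        inv[old_axis] = new_axis
    return inv


perm = [0, 3, 1, 2]
inv = invert_perm(perm)
print(inv)  # [0, 2, 3, 1]

# Axis names after transposing with inv, using "new axis i <- old axis inv[i]"
names = ("m", "n_H", "n_W", "n_C")
reordered = tuple(names[old_axis] for old_axis in inv)
print(reordered)  # ('m', 'n_W', 'n_C', 'n_H')
```

So passing perm = [0, 2, 3, 1] would send n_H to position 3, n_W to position 1 and n_C to position 2, which is what the question expected from [0, 3, 1, 2].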
