# Swap Axes
#
# Demonstrates ``np.swapaxes(arr, axis1, axis2)``, which produces an array
# whose ``axis1`` and ``axis2`` are interchanged. Each step is followed by
# assertions pinning the result it produces, so the file runs as a
# self-checking example instead of a pasted interactive transcript.

import numpy as np

# A 2x2x2 array whose entries satisfy x[i, j, k] == 4*i + 2*j + k.
x = np.array([
    [
        [0, 1], [2, 3],
    ],
    [
        [4, 5], [6, 7],
    ],
])
assert x.shape == (2, 2, 2)

# Swapping axes 0 and 2 moves element x[i, j, k] to y[k, j, i].
y = np.swapaxes(x, 0, 2)
np.testing.assert_array_equal(
    y,
    [
        [[0, 4],
         [2, 6]],
        [[1, 5],
         [3, 7]],
    ],
)
# Every axis of x has length 2, so swapping two of them leaves the shape as is.
assert y.shape == (2, 2, 2)
# Swapping two axes of a 4-D array: the lengths of the two named axes trade
# places while the other axes keep their positions.
a = np.ones((1, 2, 3, 4))
assert a.shape == (1, 2, 3, 4)

# Axes 3 (length 4) and 1 (length 2) are interchanged; axes 0 and 2 stay put.
b = np.swapaxes(a, 3, 1)
assert b.shape == (1, 4, 3, 2)
# b holds the same 24 ones as a (printed as ``1.``), laid out as 1 x 4 x 3 x 2.
np.testing.assert_array_equal(b, np.ones((1, 4, 3, 2)))