PyTorch Broadcasting Selecting Samples batch[[1]] # => (1, ...) batch[1] # => (...) New Axis == None batch = torch.randn((8, 3, 12)) batch[None] # (1, 8, 3, 12) batch[..., None] # (8, 3, 12, 1) batch[:, None] # (8, 1, 3, 12)