Tensor Tricks

#pytorch

Broadcasting
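
A minimal sketch of PyTorch broadcasting, assuming the same (8, 3, 12) batch shape used further down. The names `bias` and `scale` are illustrative, not from the original note. Shapes are aligned from the trailing dimension, and a dimension of size 1 (or a missing leading one) is stretched to match the other operand.

```python
import torch

batch = torch.randn((8, 3, 12))  # (batch, channels, features)

# Hypothetical per-feature and per-channel parameters for illustration
bias = torch.randn(12)           # (12,)   -> treated as (1, 1, 12)
scale = torch.randn((3, 1))      # (3, 1)  -> treated as (1, 3, 1)

shifted = batch + bias           # (8, 3, 12): bias is added to every sample and channel
scaled = batch * scale           # (8, 3, 12): one factor per channel, reused across features

# Check the resulting shape without computing anything
print(torch.broadcast_shapes((8, 3, 12), (3, 1)))  # torch.Size([8, 3, 12])
```

Shapes such as (8, 3, 12) and (3,) do not broadcast, because the trailing sizes 12 and 3 differ and neither is 1; PyTorch raises a `RuntimeError` in that case.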



Selecting Samples


batch[[1]]  # => (1, ...): a list index keeps the sample dimension
batch[1]  # => (...): an integer index drops it
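A runnable check of the two indexing forms, assuming a batch shaped (8, 3, 12) like the one defined under the next heading. The extra `several` and `sliced` cases are illustrative additions showing that a multi-element list and a slice also keep the dimension.

```python
import torch

batch = torch.randn((8, 3, 12))

one_kept = batch[[1]]        # list index: keeps the indexed dim -> (1, 3, 12)
one_dropped = batch[1]       # integer index: removes the dim    -> (3, 12)
several = batch[[0, 2, 5]]   # list of indices: pick samples      -> (3, 3, 12)
sliced = batch[1:2]          # slice: also keeps the dim          -> (1, 3, 12)

print(one_kept.shape)     # torch.Size([1, 3, 12])
print(one_dropped.shape)  # torch.Size([3, 12])
```

One practical difference: indexing with a list is advanced indexing and returns a copy, while `batch[1]` and `batch[1:2]` return views that share memory with `batch`.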

New Axis == None

batch = torch.randn((8, 3, 12))

batch[None]  # (1, 8, 3, 12)

batch[..., None]  # (8, 3, 12, 1)

batch[:, None]  # (8, 1, 3, 12)
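A short sketch confirming that each `None` index inserts a size-1 axis, equivalent to `unsqueeze` at the same position. The pairwise-difference example (with the illustrative name `points`) is one common reason to add axes: it sets up broadcasting between every pair of rows.

```python
import torch

batch = torch.randn((8, 3, 12))

front = batch[None]        # (1, 8, 3, 12)
back = batch[..., None]    # (8, 3, 12, 1)
second = batch[:, None]    # (8, 1, 3, 12)
both = batch[:, None, :, None]  # several new axes at once -> (8, 1, 3, 1, 12)

# None indexing matches unsqueeze at the corresponding position
assert torch.equal(front, batch.unsqueeze(0))
assert torch.equal(back, batch.unsqueeze(-1))
assert torch.equal(second, batch.unsqueeze(1))

# Pairwise differences: (5, 1, 2) - (1, 5, 2) broadcasts to (5, 5, 2),
# so diffs[i, j] == points[i] - points[j]
points = torch.randn((5, 2))
diffs = points[:, None, :] - points[None, :, :]

print(second.shape)  # torch.Size([8, 1, 3, 12])
print(diffs.shape)   # torch.Size([5, 5, 2])
```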