ClassificationData.batch_to_labels

abstract ClassificationData.batch_to_labels(batch) → Tensor

Extract the labels from a batch of data.

Parameters
batch : torch.Tensor

The batch of data.

Returns
torch.Tensor

The labels extracted from the batch, as a tensor of shape (N,), where N is the number of samples in the batch. See the Notes section for details.

Notes

The accepted label format for classification is a tensor of shape (N,), where N is the number of samples. Each element is an integer representing the class index of the corresponding sample.
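As a sketch of what satisfies this format, the helper below checks that a labels tensor is one-dimensional with an integer dtype, and shows how one-hot encoded labels (a common alternative layout) could be reduced to class indices with `argmax`. The helper name `to_class_indices` is hypothetical and not part of the library; its checks are assumptions drawn from the Notes above, not the library's own validation.

```python
import torch


def to_class_indices(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return labels as a (N,) tensor of integer class indices."""
    # Assumption: a 2-D tensor of shape (N, C) holds one-hot (or score) rows,
    # so the class index is the position of the largest value in each row.
    if labels.dim() == 2:
        labels = labels.argmax(dim=1)
    # After any conversion, expect exactly one label per sample.
    if labels.dim() != 1:
        raise ValueError(f"expected labels of shape (N,), got {tuple(labels.shape)}")
    # Class indices must be integers: reject floating-point, complex and bool dtypes.
    dtype = labels.dtype
    if dtype.is_floating_point or dtype.is_complex or dtype == torch.bool:
        raise TypeError(f"expected integer class indices, got dtype {dtype}")
    return labels


one_hot = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(to_class_indices(one_hot))  # tensor([1, 0, 2])
```

Labels that are already class indices, such as `torch.tensor([2, 0, 1])`, pass through unchanged.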

Examples

>>> def batch_to_labels(self, batch):
...     # Assumes each batch is an (images, labels) pair
...     return batch[1]
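The example above assumes each batch is an (images, labels) pair. As a sketch under that assumption, a standard PyTorch `DataLoader` over a `TensorDataset` yields batches in exactly this layout, so `batch[1]` is a labels tensor of shape (N,). The standalone `batch_to_labels` function and the random data are illustrative only; in practice this logic would live in the `batch_to_labels` override of a `ClassificationData` subclass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative data: 10 RGB images of size 32x32 with integer labels from 3 classes.
images = torch.rand(10, 3, 32, 32)
labels = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(images, labels), batch_size=4)


def batch_to_labels(batch):
    # The default collate function returns [images, labels]; index 1 holds the labels.
    return batch[1]


for batch in loader:
    batch_labels = batch_to_labels(batch)
    # Shape (N,): N is 4 for the two full batches and 2 for the last one.
    print(batch_labels.shape, batch_labels.dtype)
```

If a dataset returns samples in a different order or as a dictionary, the index (or key) used in `batch_to_labels` must change accordingly.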