ClassificationData.batch_to_labels
- abstract ClassificationData.batch_to_labels(batch) → Tensor
Extract the labels from a batch of data.
- Parameters
- batch : torch.Tensor
The batch of data.
- Returns
- torch.Tensor
The labels extracted from the batch, as a tensor of shape (N,), where N is the number of samples in the batch. See the Notes for details.
Notes
The accepted label format for classification is a tensor of shape (N,), where N is the number of samples. Each element is an integer representing a class index.
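To make the expected format concrete, here is a minimal sketch of a check an implementation could run on its own output. The helper name `validate_classification_labels` is hypothetical and not part of the library; it only encodes the rules stated above (a `torch.Tensor`, shape (N,), integer class indices), plus the assumption that class indices are non-negative.

```python
import torch

# Integer dtypes accepted as class indices (assumption: bool is excluded).
_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)


def validate_classification_labels(labels, batch_size: int) -> None:
    """Hypothetical helper: check labels match the (N,) integer-class format."""
    # Labels must be a torch.Tensor, not a list or NumPy array.
    if not isinstance(labels, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, got {type(labels).__name__}")
    # Shape must be (N,): exactly one label per sample in the batch.
    if labels.shape != (batch_size,):
        raise ValueError(f"expected shape ({batch_size},), got {tuple(labels.shape)}")
    # Each element is a class index, so the dtype must be an integer type.
    if labels.dtype not in _INT_DTYPES:
        raise ValueError(f"expected an integer dtype, got {labels.dtype}")
    # Assumption: class indices are non-negative.
    if bool((labels < 0).any()):
        raise ValueError("class indices must be non-negative")


# A batch of 3 samples with class indices 0, 2 and 1 passes silently.
validate_classification_labels(torch.tensor([0, 2, 1]), batch_size=3)
```

Float labels, one-hot rows of shape (N, C), or a plain Python list would all be rejected by this check.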
Examples
>>> def batch_to_labels(self, batch):
...     return batch[1]
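The example above assumes each batch is an `(images, labels)` pair. The following self-contained sketch shows that assumption end to end with a standard PyTorch `DataLoader` over a `TensorDataset`. `LabelExtractor` is a hypothetical stand-in for the abstract base class so the snippet runs without the library, and `MyClassificationData` is an illustrative subclass name, not part of the API.

```python
from abc import ABC, abstractmethod

import torch
from torch.utils.data import DataLoader, TensorDataset


class LabelExtractor(ABC):
    """Hypothetical stand-in for the abstract base class; not the library's class."""

    @abstractmethod
    def batch_to_labels(self, batch) -> torch.Tensor:
        """Return the labels of a batch as an (N,) tensor of class indices."""


class MyClassificationData(LabelExtractor):
    """Illustrative subclass whose loader yields (images, labels) batches."""

    def batch_to_labels(self, batch) -> torch.Tensor:
        # A TensorDataset-backed DataLoader yields [images, labels]; the
        # labels are already an (N,) int64 tensor, so return them unchanged.
        return batch[1]


# Toy data: 10 RGB images of size 8x8 with integer class labels in [0, 3).
images = torch.randn(10, 3, 8, 8)
labels = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(images, labels), batch_size=4)

data = MyClassificationData()
first_batch = next(iter(loader))
batch_labels = data.batch_to_labels(first_batch)
print(batch_labels.shape)  # torch.Size([4])
```

If a loader yields labels in another form (for example one-hot vectors), the override is the place to convert them, e.g. with `argmax(dim=1)`, so the returned tensor still has shape (N,).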