DetectionData.batch_to_labels

abstract DetectionData.batch_to_labels(batch) → List[Tensor]

Extract the labels from a batch of data.

Parameters
batch : torch.Tensor

The batch of data.

Returns
List[torch.Tensor]

The labels extracted from the batch: a list of length N containing tensors of shape (B, 5), where N is the number of samples, B is the number of bounding boxes in that sample, and each bounding box is represented by 5 values. See the Notes for details.

Notes

The accepted label format is a list of length N containing tensors of shape (B, 5), where N is the number of samples, B is the number of bounding boxes in the sample, and each bounding box is represented by 5 values: (class_id, x, y, w, h). x and y are the coordinates (in pixels) of the upper-left corner of the bounding box, w and h are the width and height of the bounding box (in pixels), and class_id is the class id of the bounding box.
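As a minimal sketch of the format described above, the snippet below builds labels for a batch of two images, one of which has no boxes, and runs them through `check_label_format`. That helper is hypothetical (it is not part of this API); it only asserts the list length and the (B, 5) shape. The non-negative width/height check is an assumed sanity check, not a documented requirement.

```python
import torch
from typing import List


def check_label_format(labels: List[torch.Tensor], num_samples: int) -> None:
    """Hypothetical helper: assert that labels follow the (class_id, x, y, w, h) layout."""
    # One tensor per sample in the batch
    assert isinstance(labels, list) and len(labels) == num_samples
    for sample_labels in labels:
        # Each tensor is (B, 5); B may differ between samples and may be 0
        assert sample_labels.ndim == 2 and sample_labels.shape[1] == 5
        # Assumed sanity check: width and height are pixel sizes, so not negative
        assert (sample_labels[:, 3:5] >= 0).all()


# A batch of N = 2 images: the first has two boxes, the second has none.
labels = [
    torch.tensor([[0., 10., 20., 50., 40.],     # class 0, top-left (10, 20), 50 x 40 px
                  [2., 100., 80., 30., 30.]]),  # class 2, top-left (100, 80), 30 x 30 px
    torch.zeros((0, 5)),                        # no boxes in this image
]
check_label_format(labels, num_samples=2)
```

An image without objects is represented by an empty tensor of shape (0, 5) rather than by omitting it, so the list length still matches N.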

Examples

>>> import torch
...
...
... def batch_to_labels(self, batch):
...     # each bbox in batch[1] is (class_id, x1, y1, x2, y2); convert to (class_id, x, y, w, h)
...     # bbox[0:1] keeps class_id 1-D so it can be concatenated; w, h = (x2 - x1, y2 - y1)
...     return [torch.stack(
...            [torch.cat((bbox[0:1], bbox[1:3], bbox[3:5] - bbox[1:3]), dim=0)
...                for bbox in image])
...             for image in batch[1]]
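The conversion in the example can also be written as a standalone, vectorized function. This is a sketch under assumptions: `xyxy_to_xywh_labels` is a hypothetical name, and `batch` is assumed to be an `(images, labels)` pair where `batch[1]` holds one tensor of `(class_id, x1, y1, x2, y2)` rows per image, mirroring the example above.

```python
import torch
from typing import List, Sequence


def xyxy_to_xywh_labels(per_image_labels: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    """Hypothetical helper: convert (class_id, x1, y1, x2, y2) rows to (class_id, x, y, w, h)."""
    converted = []
    for image_labels in per_image_labels:
        image_labels = image_labels.reshape(-1, 5)  # images with no boxes stay (0, 5)
        class_ids = image_labels[:, 0:1]            # (B, 1)
        top_left = image_labels[:, 1:3]             # (B, 2): x1, y1
        size = image_labels[:, 3:5] - top_left      # (B, 2): x2 - x1, y2 - y1
        converted.append(torch.cat((class_ids, top_left, size), dim=1))
    return converted


# Toy batch in the assumed (images, labels) layout.
images = torch.zeros((2, 3, 64, 64))
raw_labels = [torch.tensor([[1., 10., 20., 60., 50.]]), torch.zeros((0, 5))]
batch = (images, raw_labels)

labels = xyxy_to_xywh_labels(batch[1])
# expected: labels[0] has the single row (1, 10, 20, 50, 30); labels[1] has shape (0, 5)
```

Slicing whole columns instead of stacking per-box tensors also handles images with no boxes, for which `torch.stack` over an empty list would raise an error.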