DetectionData.batch_to_labels
- abstract DetectionData.batch_to_labels(batch) -> List[Tensor]
Extract the labels from a batch of data.
- Parameters
- batch : torch.Tensor
The batch of data.
- Returns
- List[torch.Tensor]
The labels extracted from the batch. The labels should be a list of length N containing tensors of shape (B, 5), where N is the number of samples, B is the number of bounding boxes in the sample, and each bounding box is represented by 5 values. See the notes for more info.
Notes
The accepted label format is a list of length N containing tensors of shape (B, 5), where N is the number of samples, B is the number of bounding boxes in the sample, and each bounding box is represented by 5 values: (class_id, x, y, w, h). x and y are the coordinates (in pixels) of the upper left corner of the bounding box, w and h are the width and height of the bounding box (in pixels), and class_id is the class id of the bounding box.
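As an illustration of this format, the sketch below builds labels for a hypothetical batch of two images. All values are made up for the example: the first image has two boxes and the second has one.

```python
import torch

# Hypothetical labels for a batch of N = 2 images.
# Each row is one bounding box: (class_id, x, y, w, h), in pixels.
labels = [
    torch.tensor([[0.0, 10.0, 20.0, 30.0, 40.0],
                  [2.0, 50.0, 60.0, 15.0, 25.0]]),  # image 0: B = 2 boxes
    torch.tensor([[1.0, 5.0, 5.0, 100.0, 80.0]]),   # image 1: B = 1 box
]

# Every per-image tensor is two-dimensional with 5 values per box.
assert all(t.ndim == 2 and t.shape[1] == 5 for t in labels)
```

B may differ from image to image, which is why the labels are a list of tensors rather than a single stacked tensor.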
Examples
>>> import torch
>>>
>>> def batch_to_labels(self, batch):
...     # Each bbox in the labels is (class_id, x1, y1, x2, y2); convert to (class_id, x, y, w, h).
...     # bbox[:1] keeps class_id as a 1-element tensor so it can be concatenated.
...     return [torch.stack(
...                 [torch.cat((bbox[:1], bbox[1:3], bbox[3:5] - bbox[1:3]), dim=0)
...                  for bbox in image])
...             for image in batch[1]]
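The same conversion can also be written as a standalone function and tried on toy data. This is only a sketch under stated assumptions: the batch is assumed to be an (images, labels) tuple whose labels are in corner format (class_id, x1, y1, x2, y2), and the helper name `xyxy_to_xywh` is hypothetical, not part of the library.

```python
import torch


def xyxy_to_xywh(image_labels: torch.Tensor) -> torch.Tensor:
    """Convert a (B, 5) tensor of (class_id, x1, y1, x2, y2) rows to (class_id, x, y, w, h)."""
    converted = image_labels.clone()  # leave the input tensor untouched
    # width = x2 - x1 and height = y2 - y1; the upper left corner (x1, y1) is kept as-is.
    converted[:, 3:5] = image_labels[:, 3:5] - image_labels[:, 1:3]
    return converted


# Hypothetical batch: (images, labels), with labels in corner format (assumption).
images = torch.zeros(2, 3, 64, 64)
corner_labels = [
    torch.tensor([[0.0, 10.0, 20.0, 40.0, 60.0]]),  # one box
    torch.zeros((0, 5)),                            # an image with no boxes
]
batch = (images, corner_labels)

labels = [xyxy_to_xywh(image_labels) for image_labels in batch[1]]
```

Operating on the whole (B, 5) tensor at once also covers images without any boxes, whereas `torch.stack` raises an error when given an empty list.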