DetectionData.infer_on_batch

abstract DetectionData.infer_on_batch(batch, model, device) → List[Tensor]

Return the predictions of the model on a batch of data.

Parameters
batch : torch.Tensor

The batch of data.

model : torch.nn.Module

The model to use for inference.

device : torch.device

The device to use for inference.

Returns
List[torch.Tensor]

The model's predictions on the batch: a list of length N, where N is the number of images in the batch. Each element is a tensor of shape (B, 6), where B is the number of bounding boxes detected in that image and each bounding box is represented by 6 values. See the Notes for the exact format.

Notes

The accepted prediction format is a list of length N containing tensors of shape (B, 6), where N is the number of images and B is the number of bounding boxes detected in that image (B may differ between images). Each bounding box is represented by 6 values: [x, y, w, h, confidence, class_id]. x and y are the coordinates (in pixels) of the upper-left corner of the bounding box, w and h are its width and height (in pixels), confidence is the model's confidence score for the detection, and class_id is the predicted class id.
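As a worked illustration of this format, the sketch below converts detections from corner coordinates (x1, y1, x2, y2), which many detectors emit, into the accepted [x, y, w, h, confidence, class_id] rows. Plain Python lists stand in for tensor rows so the arithmetic is easy to follow. The helper names `xyxy_row_to_xywh` and `is_valid_row` are hypothetical and not part of the library, and the non-negative width/height check is an extra sanity assumption rather than a documented requirement.

```python
from typing import List, Sequence


def xyxy_row_to_xywh(row: Sequence[float]) -> List[float]:
    """Convert [x1, y1, x2, y2, confidence, class_id] to [x, y, w, h, confidence, class_id].

    Hypothetical helper for illustration only. (x1, y1) is the upper-left
    corner and (x2, y2) the lower-right corner, both in pixels.
    """
    x1, y1, x2, y2, confidence, class_id = row
    # The upper-left corner is kept; width and height are the corner differences.
    return [x1, y1, x2 - x1, y2 - y1, confidence, class_id]


def is_valid_row(row: Sequence[float]) -> bool:
    """Check for 6 values and non-negative w and h (an assumed sanity check)."""
    return len(row) == 6 and row[2] >= 0 and row[3] >= 0


# One image with two detections in corner (xyxy) format.
image_preds = [
    [10.0, 20.0, 50.0, 80.0, 0.9, 3.0],
    [0.0, 0.0, 15.0, 5.0, 0.4, 1.0],
]
converted = [xyxy_row_to_xywh(r) for r in image_preds]
print(converted[0])  # [10.0, 20.0, 40.0, 60.0, 0.9, 3.0]
print(all(is_valid_row(r) for r in converted))  # True
```

In the real method, each image's rows would instead be a torch tensor of shape (B, 6), and the same subtraction is applied column-wise to the whole tensor.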

Examples

>>> import torch
...
...
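... def infer_on_batch(self, batch, model, device):
...     # Convert a YOLO prediction batch to the accepted xywh format.
...     # `device` is not used in this example.
...     return_list = []
...
...     # Run the model on the images, assumed to be the first element of the batch.
...     predictions = model(batch[0])
...     # YOLO Detections objects store a List[torch.Tensor] in .pred, one tensor per
...     # image, with rows in [x1, y1, x2, y2, confidence, class_id] (xyxy) format.
...     for single_image_tensor in predictions.pred:
...         # Clone so the model's own output is not modified in place.
...         pred_modified = torch.clone(single_image_tensor)
...         # w = x2 - x1 and h = y2 - y1; x1 and y1 already give the upper-left corner.
...         pred_modified[:, 2] = pred_modified[:, 2] - pred_modified[:, 0]
...         pred_modified[:, 3] = pred_modified[:, 3] - pred_modified[:, 1]
...         return_list.append(pred_modified)
...
...     return return_list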