DetectionData.infer_on_batch
- abstract DetectionData.infer_on_batch(batch, model, device) → List[Tensor]
Return the predictions of the model on a batch of data.
- Parameters
- batch : torch.Tensor
The batch of data.
- model : torch.nn.Module
The model to use for inference.
- device : torch.device
The device to use for inference.
- Returns
- List[torch.Tensor]
The model's predictions on the batch: a list of length N with one tensor of shape (B, 6) per image, where N is the number of images in the batch, B is the number of bounding boxes detected in that image, and each bounding box is represented by 6 values. See the Notes section for the meaning of each value.
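Because the return contract is easy to get subtly wrong (wrong number of entries, or boxes with 4 or 7 columns), it can help to check the output before returning it. The following is a minimal sketch; `validate_predictions` is a hypothetical helper, not part of the library API. It treats an image with no detections as a `(0, 6)` tensor, which follows from the shape definition with B = 0.

```python
import torch
from typing import List


def validate_predictions(preds: List[torch.Tensor], num_images: int) -> None:
    """Hypothetical helper: raise ValueError if preds breaks the (B, 6) contract."""
    if not isinstance(preds, list):
        raise ValueError(f"expected a list, got {type(preds).__name__}")
    if len(preds) != num_images:
        raise ValueError(f"expected {num_images} entries (one per image), got {len(preds)}")
    for i, p in enumerate(preds):
        if not isinstance(p, torch.Tensor):
            raise ValueError(f"entry {i} is not a torch.Tensor")
        # Each row is one box: [x, y, w, h, confidence, class_id]
        if p.ndim != 2 or p.shape[1] != 6:
            raise ValueError(f"entry {i} has shape {tuple(p.shape)}, expected (B, 6)")


# Two images: one with a single box, one with no detections (B = 0).
preds = [
    torch.tensor([[10.0, 20.0, 30.0, 40.0, 0.9, 2.0]]),
    torch.zeros((0, 6)),
]
validate_predictions(preds, num_images=2)  # passes silently
```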
Notes
The accepted prediction format is a list of length N containing tensors of shape (B, 6), where N is the number of images in the batch, B is the number of bounding boxes detected in that image, and each bounding box is represented by 6 values: [x, y, w, h, confidence, class_id]. Here x and y are the pixel coordinates of the upper-left corner of the bounding box, w and h are its width and height in pixels, confidence is the model's confidence score for the detection, and class_id is the ID of the predicted class.
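Note that x and y refer to the upper-left corner, not the box center. Assuming a model that instead emits center-based rows [cx, cy, w, h, confidence, class_id], a sketch of the conversion to the accepted format could look like this; `center_to_top_left` is a hypothetical name.

```python
import torch


def center_to_top_left(boxes: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert rows [cx, cy, w, h, confidence, class_id]
    to [x, y, w, h, confidence, class_id], where (x, y) is the upper-left corner.
    All other columns are copied unchanged."""
    out = boxes.clone()  # avoid modifying the caller's tensor
    out[:, 0] = boxes[:, 0] - boxes[:, 2] / 2  # x = cx - w / 2
    out[:, 1] = boxes[:, 1] - boxes[:, 3] / 2  # y = cy - h / 2
    return out


# A 30x40 box centred at (25, 40) has its upper-left corner at (10, 20).
centered = torch.tensor([[25.0, 40.0, 30.0, 40.0, 0.9, 2.0]])
converted = center_to_top_left(centered)
```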
Examples
>>> import torch
>>> def infer_on_batch(self, batch, model, device):
...     # Converts a YOLO prediction batch to the accepted xywh format
...     return_list = []
...
...     predictions = model(batch[0])
...     # YOLO Detections objects hold their xyxy output as a List[torch.Tensor] in .pred
...     for single_image_tensor in predictions.pred:
...         pred_modified = torch.clone(single_image_tensor)
...         # Replace x2, y2 (lower-right corner) with width and height
...         pred_modified[:, 2] = pred_modified[:, 2] - pred_modified[:, 0]
...         pred_modified[:, 3] = pred_modified[:, 3] - pred_modified[:, 1]
...         return_list.append(pred_modified)
...
...     return return_list
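Running the example above requires a YOLO model. To exercise the same xyxy-to-xywh logic without one, the sketch below uses a stand-in model whose output only mimics a `.pred` list of [x1, y1, x2, y2, confidence, class_id] tensors; `stub_model` is hypothetical, and the batch is assumed to be an `(images, labels)` tuple, as `batch[0]` suggests. It also shows a non-mutating variant of the conversion built with `torch.cat`.

```python
import torch
from types import SimpleNamespace


def stub_model(images):
    """Hypothetical stand-in for a YOLO model: returns an object with a .pred list
    holding one [x1, y1, x2, y2, confidence, class_id] tensor per image."""
    return SimpleNamespace(pred=[
        torch.tensor([[10.0, 20.0, 40.0, 60.0, 0.9, 2.0]]),
        torch.zeros((0, 6)),  # second image: no detections
    ])


def infer_on_batch(self, batch, model, device):
    predictions = model(batch[0])
    # Keep x1, y1; replace x2, y2 with width, height; keep confidence and class_id.
    return [
        torch.cat([p[:, :2], p[:, 2:4] - p[:, :2], p[:, 4:]], dim=1)
        for p in predictions.pred
    ]


batch = (torch.zeros((2, 3, 64, 64)), None)  # (images, labels); labels unused here
result = infer_on_batch(None, batch, stub_model, torch.device("cpu"))
```

Building a new tensor with `torch.cat` leaves the model's own output untouched, which achieves the same effect as the `torch.clone` call in the example.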