ClassificationData.infer_on_batch

abstract ClassificationData.infer_on_batch(batch, model, device) → Tensor

Return the predictions of the model on a batch of data.

Parameters
batch : torch.Tensor

The batch of data.

model : torch.nn.Module

The model to use for inference.

device : torch.device

The device to use for inference.

Returns
torch.Tensor

The predictions of the model on the batch. The predictions should be a tensor of shape (N, n_classes), with one column per class, where N is the number of samples in the batch.

Notes

The accepted prediction format for classification is a tensor of shape (N, n_classes), where N is the number of samples. Each row is a vector of length n_classes that holds the predicted probability of each class.
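The format described above can be checked before the predictions are handed on. The helper below, check_classification_predictions, is hypothetical and not part of the deepchecks API. It is a minimal sketch that assumes predictions are a floating-point torch.Tensor and that a small tolerance (atol) is acceptable when checking that each row sums to 1.

```python
import torch


def check_classification_predictions(preds: torch.Tensor, n_samples: int,
                                     n_classes: int, atol: float = 1e-5) -> None:
    """Hypothetical helper (not part of deepchecks): raise ValueError unless
    `preds` is an (n_samples, n_classes) tensor of per-class probabilities."""
    if not isinstance(preds, torch.Tensor):
        raise ValueError(f"expected a torch.Tensor, got {type(preds).__name__}")
    # Assumption: probabilities are stored as floats, not integer labels.
    if not preds.is_floating_point():
        raise ValueError(f"expected a floating-point tensor, got {preds.dtype}")
    # One row per sample, one column per class.
    if preds.shape != (n_samples, n_classes):
        raise ValueError(f"expected shape {(n_samples, n_classes)}, "
                         f"got {tuple(preds.shape)}")
    # Every entry must be a valid probability...
    if (preds < 0).any() or (preds > 1).any():
        raise ValueError("all entries must lie in [0, 1]")
    # ...and each row must sum to 1 (up to floating-point tolerance).
    row_sums = preds.sum(dim=1)
    if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol):
        raise ValueError("each row must sum to 1")
```

Softmax outputs pass this check, while raw logits typically fail it because they can be negative or greater than 1. A one-hot tensor such as F.one_hot(labels, num_classes=n_classes).float() also passes, since each of its rows is a (degenerate) probability distribution.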

Examples

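>>> import torch.nn.functional as F
>>> def infer_on_batch(self, batch, model, device):
...     # batch[0] holds the model inputs; move the model and inputs to the same device
...     logits = model.to(device)(batch[0].to(device))
...     # Softmax over the class dimension turns raw logits into class probabilities
...     return F.softmax(logits, dim=1)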
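To run the example end to end, the sketch below wraps the same logic in a hypothetical class MyClassificationData and calls it on a toy model and a random batch. Several assumptions are not taken from this page. The batch is an (images, labels) tuple, which is consistent with the example's use of batch[0]. The model is a tiny nn.Sequential chosen only for illustration. Inference runs under torch.no_grad() after model.eval().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyClassificationData:
    """Hypothetical stand-in for a ClassificationData subclass; only
    infer_on_batch is shown. Assumes each batch is an (images, labels) tuple."""

    def infer_on_batch(self, batch, model, device):
        model = model.to(device)
        model.eval()                       # inference behavior for dropout / batch norm
        with torch.no_grad():              # no autograd graph is needed for inference
            logits = model(batch[0].to(device))
        return F.softmax(logits, dim=1)    # (N, n_classes) class probabilities


# Tiny hypothetical model: flattens 1x8x8 images and maps them to 3 classes.
model = nn.Sequential(nn.Flatten(), nn.Linear(64, 3))
images = torch.randn(5, 1, 8, 8)
labels = torch.randint(0, 3, (5,))
preds = MyClassificationData().infer_on_batch((images, labels), model, torch.device("cpu"))
print(preds.shape)  # torch.Size([5, 3])
```

Calling model.eval() and wrapping the forward pass in torch.no_grad() avoids building an autograd graph and switches dropout and batch-norm layers to their inference behavior. Note that model.eval() changes the model's mode in place, so leave it out if the caller already manages the training/eval mode.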