ClassificationData.infer_on_batch
- abstract ClassificationData.infer_on_batch(batch, model, device) → Tensor
Return the predictions of the model on a batch of data.
- Parameters
- batch : torch.Tensor
The batch of data.
- model : torch.nn.Module
The model to use for inference.
- device : torch.device
The device to use for inference.
- Returns
- torch.Tensor
The predictions of the model on the batch. The predictions should be returned in an OHE (one-hot encoded) tensor format of shape (N, n_classes), where N is the number of samples in the batch.
Notes
The accepted prediction format for classification is a tensor of shape (N, n_classes), where N is the number of samples. Each element is an array of length n_classes representing the probability of each class.
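For instance, a prediction tensor for a hypothetical batch of 4 samples over 3 classes could look like the following (a minimal illustration of the expected shape and format only, not produced by any particular model):

>>> import torch
>>> # Each row is a probability distribution over the 3 classes
>>> predictions = torch.tensor([[0.7, 0.2, 0.1],
...                             [0.1, 0.8, 0.1],
...                             [0.3, 0.3, 0.4],
...                             [0.1, 0.1, 0.8]])
>>> predictions.shape
torch.Size([4, 3])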
Examples
>>> import torch.nn.functional as F
...
...
... def infer_on_batch(self, batch, model, device):
...     # Move model and batch to the target device, run the forward pass
...     logits = model.to(device)(batch[0].to(device))
...     # Convert raw logits into per-class probabilities
...     return F.softmax(logits, dim=1)
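Applying F.softmax over dim=1 converts the model's raw logits into a per-sample probability distribution over the classes, which matches the (N, n_classes) prediction format described in the Notes above.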