The Semantic Segmentation Data Class#
The SegmentationData class is a data class designed for semantic segmentation tasks.
It is a subclass of the VisionData
class and is used to help deepchecks load and interact with semantic segmentation data using a well-defined format.
For more info, please visit the API reference page: SegmentationData
Accepted Image Format#
All checks in deepchecks require images in the same format. They use the batch_to_images()
function in order to get
the images in the correct format. For more info on the accepted formats, please visit the
VisionData User Guide.
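For illustration, a batch of images in the accepted format might look like the following (a minimal sketch with arbitrary image sizes; see the VisionData User Guide for the full requirements):

import numpy as np

# An accepted image batch: an iterable of (H, W, C) numpy arrays with pixel
# values in the [0, 255] range.
images = [
    np.random.randint(0, 256, size=(480, 640, 3)),  # a 480x640 RGB image
    np.random.randint(0, 256, size=(320, 320, 3)),  # a 320x320 RGB image
]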
Accepted Label Format#
Deepchecks’ checks use the batch_to_labels()
function in order to get the labels in the correct format.
The accepted label format is a list of length N, where N is the number of images.
Each element is a tensor of shape (H, W), where H and W are the height and width of the
corresponding image, and its values are the true class_ids of the corresponding pixels in that image.
Note that the tensor should be 2D: the number of channels in the original image is irrelevant to the class labels.
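For example, a valid label batch for two images of different sizes could look like this (a minimal sketch with arbitrary shapes and class_ids):

import torch

# An accepted label batch for N=2 images: a list of 2D tensors of shape (H, W),
# where each value is the class_id of the corresponding pixel.
labels = [
    torch.zeros(480, 640, dtype=torch.long),  # every pixel is class_id 0
    torch.randint(0, 10, (320, 320)),         # random class_ids in [0, 10)
]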
Accepted Prediction Format#
Deepchecks’ checks use the infer_on_batch()
function in order to get the predictions of the model in the correct format.
The accepted prediction format is a list of length N, where N is the number of images.
Each element is a tensor of shape (C, H, W), where H and W are the height and width of the
corresponding image, and C is the number of classes that can be detected, with each channel
corresponding to a class_id.
Note that the values along dimension C are the per-class probabilities and should sum to 1 for each pixel.
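For example, valid predictions can be obtained by applying a softmax over the class dimension, so that the C probabilities of each pixel sum to 1 (a minimal sketch with arbitrary shapes, using random logits in place of real model outputs):

import torch

# An accepted prediction batch for N=2 images and C=10 classes: a list of
# (C, H, W) tensors, normalized with softmax along the class dimension.
predictions = [
    torch.softmax(torch.randn(10, 480, 640), dim=0),
    torch.softmax(torch.randn(10, 320, 320), dim=0),
]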
Example#
Assuming we have implemented a torch DataLoader whose underlying __getitem__ method returns a tuple of the form
(images, labels): images is a tensor of shape (N, C, H, W), in which the images' pixel values are normalized to the
[0, 1] range based on the mean and std of the ImageNet dataset, and labels is a tensor of per-class boolean masks
for each image, one mask per class_id, indicating for each pixel whether it belongs to that class (the
batch_to_labels method below converts these masks into the required (H, W) class_id format).
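A minimal sketch of such a data pipeline, with hypothetical names and random data standing in for real samples, might look like this:

import torch
from torch.utils.data import Dataset, DataLoader

class MySegmentationDataset(Dataset):
    """Hypothetical dataset yielding (image, label) pairs in the format described above."""

    def __init__(self, num_samples=8, num_classes=10, size=(3, 224, 224)):
        # Normalized images (random data here, standing in for real samples):
        self.images = [torch.randn(*size) for _ in range(num_samples)]
        # One boolean mask per class_id for each image:
        self.labels = [torch.randint(0, 2, (num_classes, size[1], size[2])).bool()
                       for _ in range(num_samples)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

your_dataloader = DataLoader(MySegmentationDataset(), batch_size=4)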
from deepchecks.vision import SegmentationData
import torch
import numpy as np
class MySegmentationTaskData(SegmentationData):
    """A deepchecks data digestion class for semantic segmentation related checks."""

    def batch_to_images(self, batch):
        """Convert a batch of images to the accepted format.

        Parameters
        ----------
        batch : torch.Tensor
            The batch of images to convert.

        Returns
        -------
        list
            A list of numpy arrays of shape (H, W, C) with values in [0, 255].
        """
        # Assuming batch[0] is a batch of (N, C, H, W) images, we convert it to (N, H, W, C).
        imgs = batch[0].detach().numpy().transpose((0, 2, 3, 1))

        # The images are normalized to the [0, 1] range based on the mean and std of the
        # ImageNet dataset, so we need to convert them back to the [0, 255] range.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        imgs = std * imgs + mean
        imgs = np.clip(imgs, 0, 1)
        imgs *= 255

        return list(imgs)

    def batch_to_labels(self, batch):
        """Convert a batch of labels to the required format.

        Parameters
        ----------
        batch : tuple
            The batch of data, containing images and labels.

        Returns
        -------
        list
            A list of size N containing tensors of shape (H, W).
        """
        # In this example, each image's label is a tensor of boolean masks, one per class_id,
        # indicating whether each pixel is of that class. We would like to convert this to a
        # format where the function returns a single mask indicating the exact class_id of
        # each pixel:
        images = batch[0]
        labels = batch[1]
        return_labels = []

        for image, label in zip(images, labels):
            # Here, class_id "0" is "background" or "no class detected".
            # The image is in (C, H, W) format, so the single mask has shape (H, W):
            ret_label = np.zeros((image.shape[1], image.shape[2]))
            # Mask to mark which pixels are already identified as classes, in case of
            # overlap between the boolean masks:
            ret_label_taken_positions = np.zeros(ret_label.shape)

            # Go over all masks of this image and combine them into a single one:
            for i in range(len(label)):
                mask = np.logical_and(np.logical_not(ret_label_taken_positions), np.array(label[i]))
                ret_label += i * mask
                # Update the taken positions:
                ret_label_taken_positions = np.logical_or(ret_label_taken_positions, mask)
            return_labels.append(torch.from_numpy(ret_label))

        return return_labels

    def infer_on_batch(self, batch, model, device):
        """Get the predictions of the model on a batch of images.

        Parameters
        ----------
        batch : tuple
            The batch of data, containing images and labels.
        model : torch.nn.Module
            The model to use for inference.
        device : torch.device
            The device to use for inference.

        Returns
        -------
        list
            A list of size N containing tensors of shape (C, H, W).
        """
        # Convert each prediction received in (H, W, C) format to (C, H, W) format:
        return_list = []
        predictions = model(batch[0].to(device))
        for single_image_tensor in predictions:
            single_image_tensor = torch.transpose(single_image_tensor, 0, 2)
            single_image_tensor = torch.transpose(single_image_tensor, 1, 2)
            return_list.append(single_image_tensor)
        return return_list
# Now, in order to test the class, we can create an instance of it:
data = MySegmentationTaskData(your_dataloader)
# And validate the implementation:
data.validate_format(your_model)
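Once the format validation passes, the data object can be passed to deepchecks checks and suites like any other VisionData object. As a rough sketch (assuming a separate test loader, here called your_test_dataloader, and that your model's task is supported by the suite's checks):

from deepchecks.vision.suites import full_suite

train_data = MySegmentationTaskData(your_dataloader)
test_data = MySegmentationTaskData(your_test_dataloader)

result = full_suite().run(train_data, test_data, your_model)
result.save_as_html('suite_result.html')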