.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "user-guide/vision/auto_tutorials/plot_classification_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_user-guide_vision_auto_tutorials_plot_classification_tutorial.py:


==============================================
Classification Model Validation Tutorial
==============================================

In this tutorial, you will learn how to validate your **classification model** using deepchecks test suites.
You can read more about the different checks and suites for computer vision use cases in the
:doc:`examples section `.

A classification model is usually used to classify an image into one of a number of classes.
Although there are multi-label use cases, in which the model assigns an image to several classes at once,
most use cases require the model to assign each image to a single class.
Currently, deepchecks supports only single-label classification (either binary or multi-class).

.. GENERATED FROM PYTHON SOURCE LINES 17-19

Defining the data and model
===========================

.. GENERATED FROM PYTHON SOURCE LINES 19-40

.. code-block:: default


    # Importing the required packages
    import os
    import urllib.request
    import zipfile

    import albumentations as A
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    import PIL.Image
    import torch
    import torchvision
    from albumentations.pytorch import ToTensorV2
    from torch import nn
    from torchvision import datasets, models, transforms
    from torchvision.datasets import ImageFolder

    import deepchecks
    from deepchecks.vision.classification_data import ClassificationData

.. GENERATED FROM PYTHON SOURCE LINES 41-44

Downloading the dataset
~~~~~~~~~~~~~~~~~~~~~~~
The data is hosted on the PyTorch website.
We will download and extract it to the current directory.

.. GENERATED FROM PYTHON SOURCE LINES 44-50

.. code-block:: default


    url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
    urllib.request.urlretrieve(url, 'hymenoptera_data.zip')

    with zipfile.ZipFile('hymenoptera_data.zip', 'r') as zip_ref:
        zip_ref.extractall('.')

.. GENERATED FROM PYTHON SOURCE LINES 51-59

Load Data
~~~~~~~~~

We will use torchvision, albumentations and torch.utils.data to load the data.

The model we are building will learn to classify **ants** and **bees**.
We have about 120 training images each for ants and bees.
There are 75 validation images for each class.
This dataset is a very small subset of ImageNet.

.. GENERATED FROM PYTHON SOURCE LINES 59-133

.. code-block:: default


    class AntsBeesDataset(ImageFolder):
        """ImageFolder whose samples are returned as NumPy arrays so that
        albumentations transforms can be applied to them."""

        def __init__(self, *args, **kwargs):
            # Same arguments as ImageFolder. Images are still read with the
            # default (PIL) loader and converted to NumPy in __getitem__.
            super().__init__(*args, **kwargs)

        def __getitem__(self, index: int):
            """
            Overrides __getitem__ to be compatible with albumentations.

            Args:
                index (int): Index

            Returns:
                tuple: (sample, target) where target is class_index of the target class.
            """
            path, target = self.samples[index]
            sample = self.loader(path)
            sample = self.get_cv2_image(sample)
            if self.transforms is not None:
                # albumentations takes keyword arguments and returns a dict
                transformed = self.transforms(image=sample, target=target)
                sample, target = transformed["image"], transformed["target"]
            else:
                if self.transform is not None:
                    sample = self.transform(image=sample)['image']
                if self.target_transform is not None:
                    target = self.target_transform(target)

            return sample, target

        def get_cv2_image(self, image):
            """Return the image as a uint8 NumPy array of shape (H, W, C)."""
            if isinstance(image, PIL.Image.Image):
                image_np = np.array(image).astype('uint8')
                return image_np
            elif isinstance(image, np.ndarray):
                return image
            else:
                raise RuntimeError("Only PIL.Image and CV2 loaders currently supported!")


    data_dir = 'hymenoptera_data'

    # Resize, center-crop and normalize with the ImageNet statistics.
    # No augmentation is applied; the same pipeline is used for both splits.
    data_transforms = A.Compose([
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    train_dataset = AntsBeesDataset(root=os.path.join(data_dir, 'train'))
    train_dataset.transforms = data_transforms

    val_dataset = AntsBeesDataset(root=os.path.join(data_dir, 'val'))
    val_dataset.transforms = data_transforms

    dataloaders = {
        'train': torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True),
        'val': torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True)
    }

    class_names = ['ants', 'bees']

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

.. GENERATED FROM PYTHON SOURCE LINES 134-137

Visualize a Few Images
~~~~~~~~~~~~~~~~~~~~~~
Let's visualize a few training images to get a sense of the data after preprocessing.

.. GENERATED FROM PYTHON SOURCE LINES 137-159

.. code-block:: default


    def imshow(inp, title=None):
        """Display a normalized (C, H, W) image tensor with matplotlib."""
        inp = inp.numpy().transpose((1, 2, 0))
        # Undo the normalization so that the colors display correctly
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated


    # Get a batch of training data
    inputs, classes = next(iter(dataloaders['train']))

    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)

    imshow(out, title=[class_names[x] for x in classes])


.. image-sg:: /user-guide/vision/auto_tutorials/images/sphx_glr_plot_classification_tutorial_001.png
   :alt: ['bees', 'ants', 'ants', 'ants']
   :srcset: /user-guide/vision/auto_tutorials/images/sphx_glr_plot_classification_tutorial_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 160-167

.. image :: /_static/images/tutorials/ants-bees.png
  :width: 400
  :alt: Ants and Bees

Downloading a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Now, we will download a ResNet-18 model from torchvision that was pre-trained on the ImageNet dataset,
and replace its final layer with a two-class output layer.

.. GENERATED FROM PYTHON SOURCE LINES 167-175

.. code-block:: default


    model = torchvision.models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    # We have only 2 classes: replace the 1000-class ImageNet head with a new,
    # randomly initialized 2-class linear layer
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(device)
    _ = model.eval()


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/runner/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
    First element is:  with len of 4
    Example output of an image shape from the dataloader  torch.Size([3, 224, 224])
    Image values tensor([[[ 1.22142, 1.10155, 1.11867, ..., -0.26843, -0.45680, -0.69655],
             [ 1.17005, 1.17005, 1.13580, ..., -0.18281, -0.35405, -0.59380],
             [ 1.15292, 1.06730, 1.10155, ..., -0.18281, -0.40543, -0.59380],
             ...,
             [ 2.04341, 2.06054, 2.09479, ..., -0.38830, -0.30268, -0.26843],
             [ 2.11191, 2.04341, 1.97491, ..., -0.35405, -0.28556, -0.25131],
             [ 2.02629, 2.00916, 2.00916, ..., -0.31980, -0.31980, -0.28556]],

            [[ 0.43277, 0.18768, 0.17017, ..., -0.60014, -0.72269, -0.88025],
             [ 0.25770, 0.13515, 0.11765, ..., -0.60014, -0.74020, -0.88025],
             [ 0.34524, 0.13515, 0.10014, ..., -0.53011, -0.68767, -0.77521],
             ...,
             [ 2.37605, 2.39356, 2.42857, ..., -0.77521, -0.68767, -0.67017],
             [ 2.37605, 2.42857, 2.42857, ..., -0.74020, -0.70518, -0.67017],
             [ 2.37605, 2.39356, 2.39356, ..., -0.77521, -0.70518, -0.67017]],

            [[-1.78702, -1.78702, -1.76959, ..., -1.76959, -1.76959, -1.76959],
             [-1.78702, -1.76959, -1.76959, ..., -1.76959, -1.78702, -1.76959],
             [-1.76959, -1.78702, -1.78702, ..., -1.80444, -1.76959, -1.78702],
             ...,
             [-1.40357, -1.31643, -1.21185, ..., -1.78702, -1.78702, -1.80444],
             [-1.42100, -1.36871, -1.29900, ..., -1.73473, -1.80444, -1.80444],
             [-1.69987, -1.45586, -1.08985, ..., -1.78702, -1.78702, -1.80444]]])
    --------------------------------------------------------------------------------
    Second element is:  with len of 4
    Example output of a label shape from the dataloader  torch.Size([])
    Image values tensor(1)

.. GENERATED FROM PYTHON SOURCE LINES 198-204

Implementing the ClassificationData class
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The first step is to implement a class that enables deepchecks to interact with your model and data.
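
The class implemented in this section has to turn the dataloader's normalized tensors back into images with pixel values in [0, 255]. ``A.Normalize`` in ``data_transforms`` maps a pixel value x to (x / 255 - mean) / std (albumentations' default ``max_pixel_value`` is 255), so the inverse is 255 * (std * z + mean). The NumPy sketch below checks that round trip on a random array; ``normalize`` and ``denormalize`` are illustrative helpers written for this example, not deepchecks or albumentations functions.

.. code-block:: python

    import numpy as np

    # ImageNet mean/std, the same values passed to A.Normalize in data_transforms
    MEAN = np.array([0.485, 0.456, 0.406])
    STD = np.array([0.229, 0.224, 0.225])


    def normalize(image_uint8):
        """Illustrative forward step: (x / 255 - mean) / std, the formula
        A.Normalize applies with its default max_pixel_value=255."""
        return (image_uint8.astype(np.float64) / 255.0 - MEAN) / STD


    def denormalize(z):
        """Illustrative inverse step: std * z + mean, clipped to [0, 1], then
        rescaled to [0, 255]. Channels are the last axis (H, W, C), so MEAN and
        STD broadcast per channel."""
        x = STD * z + MEAN
        return np.clip(x, 0.0, 1.0) * 255.0


    # A random 4x4 RGB "image" with uint8 pixel values
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)

    restored = denormalize(normalize(image))
    # The round trip recovers the original pixels up to floating-point error
    print(np.abs(restored - image).max())

The channel order matters here: the dataloader yields (N, C, H, W) tensors, which is why ``batch_to_images`` transposes them to (N, H, W, C) before applying the per-channel arithmetic.
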
The appropriate class to implement should be selected according to your model's task type.
In this tutorial, we will implement the classification task type by implementing a class that inherits from the
:class:`deepchecks.vision.classification_data.ClassificationData` class.

.. GENERATED FROM PYTHON SOURCE LINES 204-243

.. code-block:: default


    # The goal of this class is to make sure the outputs of the model and of the dataloader are in the correct format.
    # To learn more about the expected format, see the API reference for the
    # deepchecks.vision.classification_data.ClassificationData class.

    class AntsBeesData(ClassificationData):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def batch_to_images(self, batch):
            """
            Convert a batch of data to images in the expected format. The expected format is an iterable of cv2 images,
            where each image is a numpy array of shape (height, width, channels). The numbers in the array should be in
            the range [0, 255].
            """
            # (N, C, H, W) normalized tensor -> (N, H, W, C) NumPy array
            inp = batch[0].detach().numpy().transpose((0, 2, 3, 1))
            # Undo A.Normalize: x = std * z + mean, then rescale from [0, 1] to [0, 255]
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            inp = std * inp + mean
            inp = np.clip(inp, 0, 1)
            return inp * 255

        def batch_to_labels(self, batch):
            """
            Convert a batch of data to labels in the expected format. The expected format is a tensor of shape (N,),
            where N is the number of samples. Each element is an integer representing the class index.
            """
            return batch[1]

        def infer_on_batch(self, batch, model, device):
            """
            Returns the predictions for a batch of data. The expected format is a tensor of shape (N, n_classes),
            where N is the number of samples. Each element is an array of length n_classes that represents the
            probability of each class.
            """
            logits = model.to(device)(batch[0].to(device))
            # Convert raw logits to per-class probabilities
            return nn.Softmax(dim=1)(logits)

.. GENERATED FROM PYTHON SOURCE LINES 244-245

After defining the task class, we can validate it by running the following code:

.. GENERATED FROM PYTHON SOURCE LINES 245-256

.. code-block:: default


    # Must match the indices ImageFolder assigned (class folders sorted alphabetically)
    LABEL_MAP = {
      0: 'ants',
      1: 'bees'
    }
    training_data = AntsBeesData(data_loader=dataloaders["train"], label_map=LABEL_MAP)
    val_data = AntsBeesData(data_loader=dataloaders["val"], label_map=LABEL_MAP)

    training_data.validate_format(model)
    val_data.validate_format(model)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Deepchecks will try to validate the extractors given...
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[0], [0], [1], [0]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/tutorials/deepchecks_formatted_image (2).jpg
    Visual example of an image. Label class 0 Prediction class 1
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************
    Deepchecks will try to validate the extractors given...
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[0], [1], [1], [0]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/tutorials/deepchecks_formatted_image (3).jpg
    Visual example of an image. Label class 0 Prediction class 1
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************

.. GENERATED FROM PYTHON SOURCE LINES 257-263

Inspect the output above: the structure validation passes for the label, prediction and image formatters
on both datasets, and the saved sample image lets you check visually that images, labels and predictions look right.

Running Deepchecks' full suite on our data and model!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Now that we have defined the task class, we can validate the model with the full suite of deepchecks.
This takes only a few lines of code:

.. GENERATED FROM PYTHON SOURCE LINES 263-269

.. code-block:: default


    from deepchecks.vision.suites import full_suite

    suite = full_suite()
    result = suite.run(training_data, val_data, model, device=device)
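
Per-class results in the suite are only meaningful if the hard-coded ``LABEL_MAP`` agrees with the indices ``ImageFolder`` assigned when it scanned the data directory: torchvision sorts the class sub-directory names alphabetically and numbers them from 0. The sketch below re-implements that rule in plain Python on a throwaway directory so it runs without torchvision or the dataset; ``folder_label_map`` is a hypothetical helper written only for illustration.

.. code-block:: python

    import os
    import tempfile


    def folder_label_map(root):
        """Hypothetical helper mirroring torchvision's ImageFolder convention:
        class sub-directories are sorted by name and numbered from 0.
        Returns {class_index: class_name}, the shape LABEL_MAP uses."""
        with os.scandir(root) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        return {idx: name for idx, name in enumerate(classes)}


    # Build a throwaway folder tree shaped like hymenoptera_data/train
    with tempfile.TemporaryDirectory() as root:
        for name in ("bees", "ants"):  # creation order does not affect the result
            os.makedirs(os.path.join(root, name))
        label_map = folder_label_map(root)

    print(label_map)  # {0: 'ants', 1: 'bees'}

With the real dataset loaded, the same mapping can be read from the dataset object itself, e.g. ``{idx: name for name, idx in train_dataset.class_to_idx.items()}``, which avoids keeping two copies of the class list in sync.
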

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 41.123 seconds)


.. _sphx_glr_download_user-guide_vision_auto_tutorials_plot_classification_tutorial.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_classification_tutorial.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_classification_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_