.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "user-guide/vision/auto_tutorials/plot_classification_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_user-guide_vision_auto_tutorials_plot_classification_tutorial.py:


==============================================
Classification Model Validation Tutorial
==============================================

In this tutorial, you will learn how to validate your **classification model** using deepchecks test suites.
You can read more about the different checks and suites for computer vision use cases at the
:doc:`examples section `.

A classification model is usually used to classify an image into one of a number of classes. Although there are
multi-label use cases, in which the model assigns several classes to one image, most use cases
require the model to classify each image into a single class.
Currently deepchecks supports only single-label classification (either binary or multi-class).

.. GENERATED FROM PYTHON SOURCE LINES 17-19

Defining the data and model
===========================

.. GENERATED FROM PYTHON SOURCE LINES 19-40

.. code-block:: default


    # Importing the required packages
    import os
    import urllib.request
    import zipfile

    import albumentations as A
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    import PIL.Image
    import torch
    import torchvision
    from albumentations.pytorch import ToTensorV2
    from torch import nn
    from torchvision import datasets, models, transforms
    from torchvision.datasets import ImageFolder

    import deepchecks
    from deepchecks.vision.classification_data import ClassificationData

.. GENERATED FROM PYTHON SOURCE LINES 41-44

Downloading the dataset
~~~~~~~~~~~~~~~~~~~~~~~
The data is available from the torch library.
We will download and extract it to the current directory.

.. GENERATED FROM PYTHON SOURCE LINES 44-50

.. code-block:: default


    url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
    urllib.request.urlretrieve(url, 'hymenoptera_data.zip')

    with zipfile.ZipFile('hymenoptera_data.zip', 'r') as zip_ref:
        zip_ref.extractall('.')

.. GENERATED FROM PYTHON SOURCE LINES 51-59

Load Data
~~~~~~~~~

We will use torchvision and torch.utils.data packages for loading
the data.

The model we are building will learn to classify **ants** and **bees**.
We have about 120 training images each for ants and bees.
There are 75 validation images for each class.
This dataset is a very small subset of imagenet.

.. GENERATED FROM PYTHON SOURCE LINES 59-133

.. code-block:: default


    class AntsBeesDataset(ImageFolder):
        def __init__(self, *args, **kwargs):
            """
            Overrides initialization method to replace default loader with OpenCV loader
            :param args:
            :param kwargs:
            """
            super(AntsBeesDataset, self).__init__(*args, **kwargs)

        def __getitem__(self, index: int):
            """
            overrides __getitem__ to be compatible to albumentations
            Args:
                index (int): Index
            Returns:
                tuple: (sample, target) where target is class_index of the target class.
            """
            path, target = self.samples[index]
            sample = self.loader(path)
            sample = self.get_cv2_image(sample)
            if self.transforms is not None:
                transformed = self.transforms(image=sample, target=target)
                sample, target = transformed["image"], transformed["target"]
            else:
                if self.transform is not None:
                    sample = self.transform(image=sample)['image']
                if self.target_transform is not None:
                    target = self.target_transform(target)

            return sample, target

        def get_cv2_image(self, image):
            if isinstance(image, PIL.Image.Image):
                image_np = np.array(image).astype('uint8')
                return image_np
            elif isinstance(image, np.ndarray):
                return image
            else:
                raise RuntimeError("Only PIL.Image and CV2 loaders currently supported!")

    data_dir = 'hymenoptera_data'

    # Just normalization for validation
    data_transforms = A.Compose([
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    train_dataset = AntsBeesDataset(root=os.path.join(data_dir, 'train'))
    train_dataset.transforms = data_transforms

    val_dataset = AntsBeesDataset(root=os.path.join(data_dir, 'val'))
    val_dataset.transforms = data_transforms

    dataloaders = {
        'train': torch.utils.data.DataLoader(train_dataset, batch_size=4,
                                             shuffle=True),
        'val': torch.utils.data.DataLoader(val_dataset, batch_size=4,
                                           shuffle=True)
    }

    class_names = ['ants', 'bees']

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

.. GENERATED FROM PYTHON SOURCE LINES 134-137

Visualize a Few Images
~~~~~~~~~~~~~~~~~~~~~~
Let's visualize a few training images to understand the data and the transforms applied to it.

.. GENERATED FROM PYTHON SOURCE LINES 137-159

.. code-block:: default


    def imshow(inp, title=None):
        """Imshow for Tensor."""
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated


    # Get a batch of training data
    inputs, classes = next(iter(dataloaders['train']))

    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)

    imshow(out, title=[class_names[x] for x in classes])

.. image-sg:: /user-guide/vision/auto_tutorials/images/sphx_glr_plot_classification_tutorial_001.png
   :alt: ['bees', 'ants', 'bees', 'bees']
   :srcset: /user-guide/vision/auto_tutorials/images/sphx_glr_plot_classification_tutorial_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 160-167

.. image :: /_static/ants-bees.png
   :width: 400
   :alt: Ants and Bees

Downloading a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Now, we will download a pre-trained model from torchvision that was trained on the ImageNet dataset.

.. GENERATED FROM PYTHON SOURCE LINES 167-175

.. code-block:: default


    model = torchvision.models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    # We have only 2 classes
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(device)
    _ = model.eval()

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/runner/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
    0%| | 0.00/44.7M [00:00
    First element is: with len of 4
    Example output of an image shape from the dataloader torch.Size([3, 224, 224])
    Image values tensor([[[ 0.07406,  0.02269,  0.00557,  ...,  0.65631,  0.77618,  0.86180],
             [ 0.02269, -0.02868, -0.06293,  ...,  0.82755,  0.75905,  0.82755],
             [ 0.03982,  0.07406, -0.04581,  ...,  0.72481,  0.79330,  0.86180],
             ...,
             [-0.49105, -0.38830,  0.15969,  ...,  0.63918,  0.09119,  0.10831],
             [-0.31980,  0.00557, -0.42255,  ...,  0.21106, -0.49105,  0.86180],
             [-0.54243, -0.18281, -0.18281,  ...,  0.46793, -0.66230,  0.89605]],

            [[ 0.22269,  0.22269,  0.25770,  ...,  0.92297,  0.92297,  1.01050],
             [ 0.18768,  0.11765,  0.15266,  ...,  0.94048,  0.95798,  1.01050],
             [ 0.15266,  0.15266,  0.18768,  ...,  0.88796,  0.94048,  1.02801],
             ...,
             [-0.32003, -0.17997,  0.36275,  ...,  0.87045,  0.32773,  0.27521],
             [ 0.01261,  0.22269, -0.39006,  ...,  0.45028, -0.44258,  0.90546],
             [-0.49510, -0.07493, -0.02241,  ...,  0.66036, -0.51261,  1.08053]],

            [[ 0.46135,  0.51364,  0.56593,  ...,  1.28052,  1.26309,  1.36767],
             [ 0.51364,  0.49621,  0.51364,  ...,  1.24566,  1.29795,  1.35024],
             [ 0.44392,  0.49621,  0.56593,  ...,  1.19338,  1.28052,  1.36767],
             ...,
             [ 0.07791,  0.46135,  0.80993,  ...,  1.24566,  0.86222,  0.72279],
             [ 0.42649,  0.26963, -0.42754,  ...,  0.74022, -0.09638,  0.84479],
             [-0.06153,  0.30449,  0.09534,  ...,  1.03651, -0.16610,  1.24566]]])
    --------------------------------------------------------------------------------
    Second element is: with len of 4
    Example output of a label shape from the dataloader torch.Size([])
    Image values tensor(0)

.. GENERATED FROM PYTHON SOURCE LINES 198-204

Implementing the ClassificationData class
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The first step is to implement a class that enables deepchecks to interact with your model and data.
The appropriate class to implement should be selected according to your model's task type.
In this tutorial, we will implement the classification task type by implementing a class that inherits from the
:class:`deepchecks.vision.classification_data.ClassificationData` class.

.. GENERATED FROM PYTHON SOURCE LINES 204-243

.. code-block:: default


    # The goal of this class is to make sure the outputs of the model and of the dataloader are in the correct format.
    # To learn more about the expected format please visit the API reference for the
    # :class:`deepchecks.vision.classification_data.ClassificationData` class.

    class AntsBeesData(ClassificationData):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def batch_to_images(self, batch):
            """
            Convert a batch of data to images in the expected format. The expected format is an iterable of cv2 images,
            where each image is a numpy array of shape (height, width, channels). The numbers in the array should be in the
            range [0, 255]
            """
            inp = batch[0].detach().numpy().transpose((0, 2, 3, 1))
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
            inp = std * inp + mean
            inp = np.clip(inp, 0, 1)
            return inp*255

        def batch_to_labels(self, batch):
            """
            Convert a batch of data to labels in the expected format. The expected format is a tensor of shape (N,),
            where N is the number of samples. Each element is an integer representing the class index.
            """
            return batch[1]

        def infer_on_batch(self, batch, model, device):
            """
            Returns the predictions for a batch of data. The expected format is a tensor of shape (N, n_classes),
            where N is the number of samples. Each element is an array of length n_classes that represent the probability of
            each class.
            """
            logits = model.to(device)(batch[0].to(device))
            return nn.Softmax(dim=1)(logits)

.. GENERATED FROM PYTHON SOURCE LINES 244-245

After defining the task class, we can validate it by running the following code:

.. GENERATED FROM PYTHON SOURCE LINES 245-256

.. code-block:: default


    LABEL_MAP = {
        0: 'ants',
        1: 'bees'
    }
    training_data = AntsBeesData(data_loader=dataloaders["train"], label_map=LABEL_MAP)
    val_data = AntsBeesData(data_loader=dataloaders["val"], label_map=LABEL_MAP)

    training_data.validate_format(model)
    val_data.validate_format(model)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Deepchecks will try to validate the extractors given...
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[1], [1], [0], [1]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/tutorials/deepchecks_formatted_image (2).jpg
    Visual example of an image. Label class 1 Prediction class 1
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************
    Deepchecks will try to validate the extractors given...
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[0], [1], [0], [1]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/tutorials/deepchecks_formatted_image (3).jpg
    Visual example of an image. Label class 0 Prediction class 1
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************

.. GENERATED FROM PYTHON SOURCE LINES 257-263

Observe the output above to confirm that the images, labels and predictions look as expected.

Running Deepchecks' full suite on our data and model!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Now that we have defined the task class, we can validate the model with the full suite of deepchecks.
This takes only a few lines of code:

.. GENERATED FROM PYTHON SOURCE LINES 263-269

.. code-block:: default


    from deepchecks.vision.suites import full_suite

    suite = full_suite()
    result = suite.run(training_data, val_data, model, device=device)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Validating Input: 0%| | 0/1 [00:00

Full Suite

The suite is composed of various checks such as: Image Dataset Drift, Similar Image Leakage, Image Property Drift, etc...
Each check may contain conditions (each of which results in a pass, fail, warning, or error status) as well as other outputs such as plots or tables.
Suites, checks and conditions can all be modified. Read more about custom suites.


Conditions Summary

.. list-table::
   :header-rows: 1

   * - Check
     - Condition
     - More Info
   * - Simple Model Comparison
     - Model performance gain over simple model is not less than 10%
     - Found metrics with gain below threshold: {'F1': {0: '-17.38%'}}
   * - Image Segment Performance - Test Dataset
     - No segment with ratio between score to mean less than 80%
     - Properties with failed segments: Brightness: {'Range': '[0.51, 0.57)', 'Metric': 'Precision', 'Ratio': 0.76}
   * - Similar Image Leakage
     - Number of similar images between train and test is not greater than 0
     - Number of similar images between train and test datasets: 1
   * - Class Performance
     - Train-Test scores relative degradation is not greater than 0.1
     -
   * - Train Test Prediction Drift
     - PSI <= 0.15 and Earth Mover's Distance <= 0.075 for prediction drift
     -
   * - Image Property Drift
     - Earth Mover's Distance <= 0.1 for image properties drift
     -
   * - Simple Feature Contribution
     - Train-Test properties' Predictive Power Score difference is not greater than 0.2
     -
   * - New Labels
     - Percentage of new labels in the test set not above 0.5%.
     -
   * - Image Segment Performance - Train Dataset
     - No segment with ratio between score to mean less than 80%
     -
   * - Train Test Label Drift
     - PSI <= 0.15 and Earth Mover's Distance <= 0.075 for label drift
     -

Check With Conditions Output

Class Performance

Summarize given metrics on a dataset and model.

Conditions Summary
Status Condition More Info
Train-Test scores relative degradation is not greater than 0.1
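
The condition above compares each train score with its test counterpart. The sketch below shows one plausible reading of "relative degradation", assuming it means ``(train - test) / train``. The function name and the per-class scores are hypothetical and are not taken from deepchecks or from this run.

```python
def relative_degradation(train_score: float, test_score: float) -> float:
    """Fraction of the train score that is lost on the test set (assumed definition)."""
    if train_score == 0:
        return 0.0  # nothing to degrade from; avoids division by zero
    return (train_score - test_score) / train_score


# Hypothetical per-class scores, for illustration only.
train_scores = {"ants": 0.80, "bees": 0.90}
test_scores = {"ants": 0.76, "bees": 0.72}

for name in train_scores:
    drop = relative_degradation(train_scores[name], test_scores[name])
    status = "pass" if drop <= 0.1 else "fail"
    print(f"{name}: relative degradation {drop:.2f} -> {status}")
```

With these made-up numbers the first class stays within the 0.1 threshold and the second does not.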
Additional Outputs


Train Test Prediction Drift

Calculate prediction drift between train dataset and test dataset, using statistical measures.

Conditions Summary
Status Condition More Info
PSI <= 0.15 and Earth Mover's Distance <= 0.075 for prediction drift
Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following prediction properties: ['Samples Per Class'].
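
For a categorical property such as samples per class, the Population Stability Index (PSI) named in the condition can be sketched as follows. This is a minimal illustration of the standard PSI formula, not the deepchecks implementation. The helper names, the ``eps`` smoothing value and the prediction counts are assumptions.

```python
import math
from collections import Counter


def class_proportions(labels, classes):
    """Share of samples that fall into each class."""
    counts = Counter(labels)
    return [counts[c] / len(labels) for c in classes]


def psi(expected, actual, eps=1e-6):
    """Population Stability Index between two discrete distributions."""
    total = 0.0
    for e, a in zip(expected, actual):
        e, a = max(e, eps), max(a, eps)  # guard against log(0)
        total += (a - e) * math.log(a / e)
    return total


# Hypothetical predicted classes (0 = ants, 1 = bees).
train_pred = [0] * 120 + [1] * 124
test_pred = [0] * 70 + [1] * 83

score = psi(class_proportions(train_pred, [0, 1]),
            class_proportions(test_pred, [0, 1]))
print(f"PSI = {score:.4f}, condition holds: {score <= 0.15}")
```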


Image Property Drift

Calculate drift between train dataset and test dataset per image property, using statistical measures.

Conditions Summary
Status Condition More Info
Earth Mover's Distance <= 0.1 for image properties drift
Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following image properties: ['Area', 'Aspect Ratio', 'Brightness', 'Mean Blue Relative Intensity', 'Mean Green Relative Intensity', 'Mean Red Relative Intensity', 'RMS Contrast'].
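
For a numeric property such as brightness, the Earth Mover's Distance can be sketched as the area between the two empirical cumulative distribution functions. The version below divides by the observed value range so the result falls in [0, 1]. That scaling is an assumption made for this illustration, and the sample values are hypothetical.

```python
def earth_movers_distance(xs, ys):
    """Wasserstein-1 distance between two 1-D samples, scaled by the value range."""
    xs, ys = sorted(xs), sorted(ys)
    points = sorted(set(xs) | set(ys))
    lo, hi = points[0], points[-1]
    if hi == lo:
        return 0.0  # all values identical, so no drift
    distance = 0.0
    i = j = 0
    for left, right in zip(points, points[1:]):
        # Advance the counters so i/len(xs) and j/len(ys) are the CDFs at `left`.
        while i < len(xs) and xs[i] <= left:
            i += 1
        while j < len(ys) and ys[j] <= left:
            j += 1
        cdf_gap = abs(i / len(xs) - j / len(ys))
        distance += cdf_gap * (right - left)
    return distance / (hi - lo)


# Hypothetical brightness values for a handful of images.
train_brightness = [0.2, 0.4, 0.6]
test_brightness = [0.3, 0.5, 0.7]
print(round(earth_movers_distance(train_brightness, test_brightness), 3))
```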


Simple Feature Contribution

Return the Predictive Power Score of image properties, in order to estimate their ability to predict the label.

Conditions Summary
Status Condition More Info
Train-Test properties' Predictive Power Score difference is not greater than 0.2
Additional Outputs
The Predictive Power Score (PPS) is used to estimate the ability of an image property (such as brightness) to predict the label by itself. (Read more about Predictive Power Score)
In the graph above, we should suspect we have problems in our data if:
1. Train dataset PPS values are high:
A high PPS (close to 1) can mean that there's a bias in the dataset, as a single property can predict the label successfully, using simple classic ML algorithms
2. Large difference between train and test PPS (train PPS is larger):
An even more powerful indication of dataset bias, as an image property that was powerful in train
but not in test can be explained by bias in train that is not relevant to a new dataset.
3. Large difference between test and train PPS (test PPS is larger):
An anomalous value, could indicate drift in test dataset that caused a coincidental correlation to the target label.


Simple Model Comparison

Compare given model score to simple model score (according to given model type).

Conditions Summary
Status Condition More Info
Model performance gain over simple model is not less than 10% Found metrics with gain below threshold: {'F1': {0: '-17.38%'}}
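
One common way to express a gain over a baseline is the share of the remaining headroom above the simple model that the real model captures. The sketch below uses that formula purely as an assumption. It is not confirmed to be the formula deepchecks uses, and the scores are hypothetical rather than the ones behind the reported -17.38%.

```python
def gain_over_simple(model_score, simple_score, perfect_score=1.0):
    """Gain of the model relative to the headroom left by the simple model (assumed formula)."""
    headroom = perfect_score - simple_score
    if headroom == 0:
        return 0.0  # the simple model is already perfect
    return (model_score - simple_score) / headroom


# Hypothetical F1 scores for a single class.
for model_f1, simple_f1 in [(0.80, 0.50), (0.45, 0.50)]:
    gain = gain_over_simple(model_f1, simple_f1)
    status = "pass" if gain >= 0.10 else "fail"
    print(f"model={model_f1:.2f} simple={simple_f1:.2f} gain={gain:+.0%} -> {status}")
```

A negative gain, as in the second hypothetical case, means the model scores below the simple baseline on that metric.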
Additional Outputs


Image Segment Performance - Test Dataset

Segment the data by various properties of the image, and compare the performance of the segments.

Conditions Summary
Status Condition More Info
No segment with ratio between score to mean less than 80% Properties with failed segments: Brightness: {'Range': '[0.51, 0.57)', 'Metric': 'Precision', 'Ratio': 0.76}
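
The ratio in the condition above can be illustrated by dividing each segment's score by a mean score and flagging segments that fall below 80%. This sketch assumes the mean is taken over the segments themselves. The segment ranges and precision values are hypothetical.

```python
def segment_ratios(scores_by_segment):
    """Ratio of each segment's score to the mean score over all segments."""
    mean_score = sum(scores_by_segment.values()) / len(scores_by_segment)
    return {seg: score / mean_score for seg, score in scores_by_segment.items()}


# Hypothetical precision per brightness segment.
precision_by_brightness = {
    "[0.20, 0.51)": 0.90,
    "[0.51, 0.57)": 0.60,
    "[0.57, 0.74]": 0.90,
}

ratios = segment_ratios(precision_by_brightness)
failed = [seg for seg, ratio in ratios.items() if ratio < 0.8]
print("segments below 80% of the mean:", failed)
```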
Additional Outputs


Image Segment Performance - Train Dataset

Segment the data by various properties of the image, and compare the performance of the segments.

Conditions Summary
Status Condition More Info
No segment with ratio between score to mean less than 80%
Additional Outputs


Similar Image Leakage

Check for images in training that are similar to images in test.
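
A simple way to detect near-duplicate images is to compare perceptual hashes. The sketch below uses an average hash and a Hamming distance on tiny grayscale grids. It illustrates the general idea only and is not the method deepchecks is confirmed to use. A real pipeline would first downscale each image to a small fixed size, and the pixel values here are made up.

```python
def average_hash(gray):
    """One bit per pixel: 1 if the pixel is brighter than the image mean."""
    pixels = [p for row in gray for p in row]
    mean = sum(pixels) / len(pixels)
    return [1 if p > mean else 0 for p in pixels]


def hamming_distance(hash_a, hash_b):
    """Number of positions at which two hashes differ."""
    return sum(a != b for a, b in zip(hash_a, hash_b))


# Hypothetical 4x4 grayscale images.
train_img = [[10, 20, 200, 210], [15, 25, 205, 215],
             [10, 20, 200, 210], [15, 25, 205, 215]]
brighter_copy = [[p + 5 for p in row] for row in train_img]   # near duplicate
mirrored = [list(reversed(row)) for row in train_img]         # different layout

print(hamming_distance(average_hash(train_img), average_hash(brighter_copy)))
print(hamming_distance(average_hash(train_img), average_hash(mirrored)))
```

A small distance suggests that the two images are near duplicates. A large one suggests unrelated content.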

Conditions Summary
Status Condition More Info
Number of similar images between train and test is not greater than 0 Number of similar images between train and test datasets: 1
Additional Outputs

Similar Images

Total number of test samples with similar images in train: 1

Samples

Train
Test


Train Test Label Drift

Calculate label drift between train dataset and test dataset, using statistical measures.

Conditions Summary
Status Condition More Info
PSI <= 0.15 and Earth Mover's Distance <= 0.075 for label drift
Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following label properties: ['Samples Per Class'].


Check Without Conditions Output

Image Property Outliers - Test Dataset

Find outlier images with respect to the given properties.
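
A "non-outliers range" like the ones reported in this check can be derived with the interquartile range (IQR) rule. The sketch below assumes quartiles at 25% and 75% and a multiplier of 1.5. Both are assumptions for illustration rather than confirmed deepchecks defaults, and the brightness values are hypothetical.

```python
import statistics


def non_outlier_range(values, iqr_scale=1.5):
    """Lower and upper bounds beyond which a value counts as an outlier (IQR rule)."""
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    return q1 - iqr_scale * iqr, q3 + iqr_scale * iqr


# Hypothetical brightness values, one per image.
brightness = [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.95]

low, high = non_outlier_range(brightness)
outliers = [v for v in brightness if v < low or v > high]
print(f"non-outliers range: {low:.2f} to {high:.2f}, outliers: {outliers}")
```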

Additional Outputs

Property "Aspect Ratio"

No outliers found.

Property "Area"

No outliers found.

Property "Brightness"

Total number of outliers: 1
Non-outliers range: 0.2 to 0.74
Brightness values of the displayed outliers: 0.19

Property "RMS Contrast"

No outliers found.

Property "Mean Red Relative Intensity"

Total number of outliers: 8
Non-outliers range: 0.22 to 0.55
Mean Red Relative Intensity values of the displayed outliers: 0.15, 0.21, 0.61, 0.63, 0.66, 0.67, 0.7

Property "Mean Green Relative Intensity"

Total number of outliers: 11
Non-outliers range: 0.23 to 0.49
Mean Green Relative Intensity values of the displayed outliers: 0.16, 0.21, 0.55, 0.55, 0.66, 0.72, 0.74

Property "Mean Blue Relative Intensity"

No outliers found.


Image Property Outliers - Train Dataset

Find outlier images with respect to the given properties.

Additional Outputs

Property "Aspect Ratio"

No outliers found.

Property "Area"

No outliers found.

Property "Brightness"

Total number of outliers: 5
Non-outliers range: 0.21 to 0.75
Brightness values of the displayed outliers: 0.76, 0.78, 0.79, 0.82, 0.91

Property "RMS Contrast"

Total number of outliers: 3
Non-outliers range: 0.08 to 0.34
RMS Contrast values of the displayed outliers: 0.07, 0.35, 0.37

Property "Mean Red Relative Intensity"

Total number of outliers: 14
Non-outliers range: 0.22 to 0.56
Mean Red Relative Intensity values of the displayed outliers: 0.13, 0.16, 0.18, 0.2, 0.6, 0.63, 0.65, 0.72, 0.83

Property "Mean Green Relative Intensity"

Total number of outliers: 15
Non-outliers range: 0.21 to 0.52
Mean Green Relative Intensity values of the displayed outliers: 0.16, 0.17, 0.17, 0.17, 0.18, 0.55, 0.57, 0.57, 0.58, 0.59

Property "Mean Blue Relative Intensity"

No outliers found.


Confusion Matrix - Test Dataset

Calculate the confusion matrix of the model on the given dataset.
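
As a reminder of what such a matrix contains, the following minimal sketch counts label and prediction pairs for the two classes of this tutorial. The label and prediction lists are hypothetical, and the helper is written from scratch rather than taken from deepchecks.

```python
def confusion_matrix(labels, predictions, n_classes):
    """matrix[t][p] counts samples whose true class is t and predicted class is p."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for true_class, predicted_class in zip(labels, predictions):
        matrix[true_class][predicted_class] += 1
    return matrix


# Hypothetical class indices (0 = ants, 1 = bees).
labels = [0, 0, 1, 1, 1, 0]
predictions = [0, 1, 1, 1, 0, 0]

for name, row in zip(["ants", "bees"], confusion_matrix(labels, predictions, 2)):
    print(name, row)
```

The diagonal holds the correctly classified samples. Off-diagonal cells show which class the model confused with which.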

Additional Outputs
Showing 10 of 2 classes:


Confusion Matrix - Train Dataset

Calculate the confusion matrix of the model on the given dataset.

Additional Outputs
Showing 10 of 2 classes:


Heatmap Comparison

Check if the average image brightness (or bbox location if applicable) is similar between train and test set.

Additional Outputs


Other Checks That Weren't Displayed

Check Reason
Mean Average Precision Report - Test Dataset Check is irrelevant for task of type TaskType.CLASSIFICATION
Mean Average Precision Report - Train Dataset Check is irrelevant for task of type TaskType.CLASSIFICATION
Label Property Outliers - Test Dataset task type classification does not have default label properties for label outliers.
Label Property Outliers - Train Dataset task type classification does not have default label properties for label outliers.
Mean Average Recall Report - Test Dataset Check is irrelevant for task of type TaskType.CLASSIFICATION
Mean Average Recall Report - Train Dataset Check is irrelevant for task of type TaskType.CLASSIFICATION
Model Error Analysis Unable to train meaningful error model (r^2 score: 0.04)
Image Dataset Drift Nothing found
New Labels Nothing found


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 26.287 seconds)


.. _sphx_glr_download_user-guide_vision_auto_tutorials_plot_classification_tutorial.py:


.. only :: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example


        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_classification_tutorial.py `


        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_classification_tutorial.ipynb `


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_