Train Test Prediction Drift#

This notebooks provides an overview for using and understanding the vision prediction drift check.

Structure:

What Is a Prediction Drift?#

The term drift (and all it’s derivatives) is used to describe any change in the data compared to the data the model was trained on. Prediction drift refers to the case in which a change in the data (data/feature drift) has happened and as a result, the distribution of the models’ prediction has changed.

Calculating prediction drift is especially useful in cases in which labels are not available for the test dataset, and so a drift in the predictions is out only indication that a changed has happened in the data that actually affects model predictions. If labels are available, it’s also recommended to run the :doc:`Label Drift Check </examples/vision/checks/distribution/examples/plot_train_test_label_drift>.

There are two main causes for prediction drift:

  • A change in the sample population. In this case, the underline phenomenon we’re trying to predict behaves the same, but we’re not getting the same types of samples. For example, cronuts becoming more popular in a food classification dataset.

  • Concept drift, which means that the underline relation between the data and the label has changed. For example, the arctic hare changes its fur color during the winter. A dataset that was trained on summertime hares, would have difficulty identifying them in winter. Important to note that concept drift won’t necessarily result in prediction drift, unless it affects features that are of high importance to the model.

How Does the TrainTestPredictionDrift Check Work?#

There are many methods to detect drift, that usually include statistical methods that aim to measure difference between 2 distributions. We experimented with various approaches and found that for detecting drift between 2 one-dimensional distributions, the following 2 methods give the best results:

However, one does not simply measure drift on a prediction, as they may be complex structures. These methods are implemented on label properties, as described in the next section.

Different measurement on predictions#

In computer vision specifically, our predictions may be complex, and measuring their drift is not a straightforward task. Therefore, we calculate drift on different properties of the labels, on which we can directly measure drift.

Which Prediction Properties Are Used?#

Task Type

Property name

What is it

Classification

Samples Per Class

Number of images per class

Object Detection

Samples Per Class

Number of bounding boxes per class

Object Detection

Bounding Box Area

Area of bounding box (height * width)

Object Detection

Number of Bounding Boxes Per Image

Number of bounding box objects in each image

Run the check on a Classification task (MNIST)#

Imports#

from deepchecks.vision.checks import TrainTestPredictionDrift
from deepchecks.vision.datasets.classification.mnist import (load_dataset,
                                                             load_model)

Loading data and model:#

train_ds = load_dataset(train=True, batch_size=64, object_type='VisionData')
test_ds = load_dataset(train=False, batch_size=64, object_type='VisionData')
model = load_model()

Running TrainTestLabelDrift on classification#

Out:

Validating Input:   0%| | 0/1 [00:00<?, ? /s]


Ingesting Batches - Train Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  10%|###############                                                                                                                                              | 15/157 [00:00<00:00, 144.23 Batch/s]
Ingesting Batches - Train Dataset:  20%|################################                                                                                                                             | 32/157 [00:00<00:00, 159.13 Batch/s]
Ingesting Batches - Train Dataset:  31%|#################################################                                                                                                            | 49/157 [00:00<00:00, 163.05 Batch/s]
Ingesting Batches - Train Dataset:  42%|##################################################################                                                                                           | 66/157 [00:00<00:00, 164.89 Batch/s]
Ingesting Batches - Train Dataset:  53%|###################################################################################                                                                          | 83/157 [00:00<00:00, 157.65 Batch/s]
Ingesting Batches - Train Dataset:  64%|####################################################################################################                                                         | 100/157 [00:00<00:00, 159.87 Batch/s]
Ingesting Batches - Train Dataset:  75%|####################################################################################################################9                                        | 117/157 [00:00<00:00, 157.54 Batch/s]
Ingesting Batches - Train Dataset:  85%|######################################################################################################################################                       | 134/157 [00:00<00:00, 160.68 Batch/s]
Ingesting Batches - Train Dataset:  96%|#######################################################################################################################################################      | 151/157 [00:00<00:00, 161.25 Batch/s]


Ingesting Batches - Test Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  11%|#################                                                                                                                                            | 17/157 [00:00<00:00, 163.20 Batch/s]
Ingesting Batches - Test Dataset:  22%|##################################                                                                                                                           | 34/157 [00:00<00:00, 164.93 Batch/s]
Ingesting Batches - Test Dataset:  32%|###################################################                                                                                                          | 51/157 [00:00<00:00, 155.83 Batch/s]
Ingesting Batches - Test Dataset:  43%|####################################################################                                                                                         | 68/157 [00:00<00:00, 159.03 Batch/s]
Ingesting Batches - Test Dataset:  54%|#####################################################################################                                                                        | 85/157 [00:00<00:00, 161.12 Batch/s]
Ingesting Batches - Test Dataset:  65%|######################################################################################################                                                       | 102/157 [00:00<00:00, 161.82 Batch/s]
Ingesting Batches - Test Dataset:  76%|#######################################################################################################################                                      | 119/157 [00:00<00:00, 163.63 Batch/s]
Ingesting Batches - Test Dataset:  87%|########################################################################################################################################                     | 136/157 [00:00<00:00, 164.86 Batch/s]
Ingesting Batches - Test Dataset:  97%|#########################################################################################################################################################    | 153/157 [00:00<00:00, 164.74 Batch/s]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]

Train Test Prediction Drift

Calculate prediction drift between train dataset and test dataset, using statistical measures.

Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following prediction properties: ['Samples Per Class'].

Note - data sampling: Running on 10000 train data samples out of 60000. Sample size can be controlled with the "n_samples" parameter.



Understanding the results#

We can see there is almost no drift between the train & test labels. This means the split to train and test was good (as it is balanced and random). Let’s check the performance of a simple model trained on MNIST.

from deepchecks.vision.checks import ClassPerformance

ClassPerformance().run(train_ds, test_ds, model)

Out:

Validating Input:   0%| | 0/1 [00:00<?, ? /s]


Ingesting Batches - Train Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  10%|################                                                                                                                                             | 16/157 [00:00<00:00, 153.42 Batch/s]
Ingesting Batches - Train Dataset:  21%|#################################                                                                                                                            | 33/157 [00:00<00:00, 157.50 Batch/s]
Ingesting Batches - Train Dataset:  32%|##################################################                                                                                                           | 50/157 [00:00<00:00, 161.42 Batch/s]
Ingesting Batches - Train Dataset:  43%|###################################################################                                                                                          | 67/157 [00:00<00:00, 162.92 Batch/s]
Ingesting Batches - Train Dataset:  54%|####################################################################################                                                                         | 84/157 [00:00<00:00, 164.22 Batch/s]
Ingesting Batches - Train Dataset:  64%|#####################################################################################################                                                        | 101/157 [00:00<00:00, 164.85 Batch/s]
Ingesting Batches - Train Dataset:  75%|######################################################################################################################                                       | 118/157 [00:00<00:00, 165.18 Batch/s]
Ingesting Batches - Train Dataset:  86%|#######################################################################################################################################                      | 135/157 [00:00<00:00, 165.74 Batch/s]
Ingesting Batches - Train Dataset:  97%|########################################################################################################################################################     | 152/157 [00:00<00:00, 165.44 Batch/s]


Ingesting Batches - Test Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  11%|#################                                                                                                                                            | 17/157 [00:00<00:00, 166.85 Batch/s]
Ingesting Batches - Test Dataset:  22%|##################################                                                                                                                           | 34/157 [00:00<00:00, 167.46 Batch/s]
Ingesting Batches - Test Dataset:  32%|###################################################                                                                                                          | 51/157 [00:00<00:00, 167.48 Batch/s]
Ingesting Batches - Test Dataset:  43%|####################################################################                                                                                         | 68/157 [00:00<00:00, 165.55 Batch/s]
Ingesting Batches - Test Dataset:  54%|#####################################################################################                                                                        | 85/157 [00:00<00:00, 163.61 Batch/s]
Ingesting Batches - Test Dataset:  65%|######################################################################################################                                                       | 102/157 [00:00<00:00, 160.90 Batch/s]
Ingesting Batches - Test Dataset:  76%|#######################################################################################################################                                      | 119/157 [00:00<00:00, 160.29 Batch/s]
Ingesting Batches - Test Dataset:  87%|########################################################################################################################################                     | 136/157 [00:00<00:00, 158.35 Batch/s]
Ingesting Batches - Test Dataset:  97%|########################################################################################################################################################     | 152/157 [00:00<00:00, 158.57 Batch/s]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]

Class Performance

Summarize given metrics on a dataset and model.

Additional Outputs

Note - data sampling: Running on 10000 train data samples out of 60000. Sample size can be controlled with the "n_samples" parameter.



MNIST with label drift#

Now, let’s try to separate the MNIST dataset in a different manner that will result in a prediction drift, and see how it affects the performance. We are going to create a custom collate_fn in the test dataset, that will select samples with class 0 in a 1/10 chances.

import torch

mnist_dataloader_train = load_dataset(train=True, batch_size=64, object_type='DataLoader')
mnist_dataloader_test = load_dataset(train=False, batch_size=64, object_type='DataLoader')
full_mnist = torch.utils.data.ConcatDataset([mnist_dataloader_train.dataset, mnist_dataloader_test.dataset])
train_dataset, test_dataset = torch.utils.data.random_split(full_mnist, [60000,10000], generator=torch.Generator().manual_seed(42))

Inserting drift to the test set#

import numpy as np
from torch.utils.data._utils.collate import default_collate

np.random.seed(42)


def collate_test(batch):
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 0:
            if np.random.randint(5) == 0:
                modified_batch.append(item)
            else:
                modified_batch.append((image, 1))
        else:
            modified_batch.append(item)

    return default_collate(modified_batch)

mod_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
mod_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=collate_test)
from deepchecks.vision.datasets.classification.mnist import MNISTData

mod_train_ds = MNISTData(mod_train_loader)
mod_test_ds = MNISTData(mod_test_loader)

# Run the check
# -------------

check = TrainTestPredictionDrift()
check.run(mod_train_ds, mod_test_ds, model)

# Add a condition
# ---------------
# We could also add a condition to the check to alert us to changes in the prediction
# distribution, such as the one that occurred here.

check = TrainTestPredictionDrift().add_condition_drift_score_not_greater_than()
check.run(mod_train_ds, mod_test_ds, model)

Out:

Validating Input:   0%| | 0/1 [00:00<?, ? /s]


Ingesting Batches - Train Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  11%|#################                                                                                                                                            | 17/157 [00:00<00:00, 166.31 Batch/s]
Ingesting Batches - Train Dataset:  22%|##################################                                                                                                                           | 34/157 [00:00<00:00, 166.78 Batch/s]
Ingesting Batches - Train Dataset:  32%|###################################################                                                                                                          | 51/157 [00:00<00:00, 166.78 Batch/s]
Ingesting Batches - Train Dataset:  43%|####################################################################                                                                                         | 68/157 [00:00<00:00, 167.13 Batch/s]
Ingesting Batches - Train Dataset:  54%|#####################################################################################                                                                        | 85/157 [00:00<00:00, 166.50 Batch/s]
Ingesting Batches - Train Dataset:  65%|######################################################################################################                                                       | 102/157 [00:00<00:00, 166.16 Batch/s]
Ingesting Batches - Train Dataset:  76%|#######################################################################################################################                                      | 119/157 [00:00<00:00, 166.44 Batch/s]
Ingesting Batches - Train Dataset:  87%|########################################################################################################################################                     | 136/157 [00:00<00:00, 166.28 Batch/s]
Ingesting Batches - Train Dataset:  97%|#########################################################################################################################################################    | 153/157 [00:00<00:00, 165.34 Batch/s]


Ingesting Batches - Test Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  11%|#################                                                                                                                                            | 17/157 [00:00<00:00, 163.49 Batch/s]
Ingesting Batches - Test Dataset:  22%|##################################                                                                                                                           | 34/157 [00:00<00:00, 163.97 Batch/s]
Ingesting Batches - Test Dataset:  32%|###################################################                                                                                                          | 51/157 [00:00<00:00, 163.13 Batch/s]
Ingesting Batches - Test Dataset:  43%|####################################################################                                                                                         | 68/157 [00:00<00:00, 160.78 Batch/s]
Ingesting Batches - Test Dataset:  54%|#####################################################################################                                                                        | 85/157 [00:00<00:00, 159.96 Batch/s]
Ingesting Batches - Test Dataset:  65%|######################################################################################################                                                       | 102/157 [00:00<00:00, 160.74 Batch/s]
Ingesting Batches - Test Dataset:  76%|#######################################################################################################################                                      | 119/157 [00:00<00:00, 161.05 Batch/s]
Ingesting Batches - Test Dataset:  87%|########################################################################################################################################                     | 136/157 [00:00<00:00, 160.79 Batch/s]
Ingesting Batches - Test Dataset:  97%|#########################################################################################################################################################    | 153/157 [00:00<00:00, 155.25 Batch/s]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]


Validating Input:   0%| | 0/1 [00:00<?, ? /s]


Ingesting Batches - Train Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  10%|################                                                                                                                                             | 16/157 [00:00<00:00, 159.49 Batch/s]
Ingesting Batches - Train Dataset:  21%|#################################                                                                                                                            | 33/157 [00:00<00:00, 160.35 Batch/s]
Ingesting Batches - Train Dataset:  32%|##################################################                                                                                                           | 50/157 [00:00<00:00, 159.77 Batch/s]
Ingesting Batches - Train Dataset:  43%|###################################################################                                                                                          | 67/157 [00:00<00:00, 161.05 Batch/s]
Ingesting Batches - Train Dataset:  54%|####################################################################################                                                                         | 84/157 [00:00<00:00, 161.30 Batch/s]
Ingesting Batches - Train Dataset:  64%|#####################################################################################################                                                        | 101/157 [00:00<00:00, 162.31 Batch/s]
Ingesting Batches - Train Dataset:  75%|######################################################################################################################                                       | 118/157 [00:00<00:00, 163.31 Batch/s]
Ingesting Batches - Train Dataset:  86%|#######################################################################################################################################                      | 135/157 [00:00<00:00, 163.98 Batch/s]
Ingesting Batches - Train Dataset:  97%|########################################################################################################################################################     | 152/157 [00:00<00:00, 164.75 Batch/s]


Ingesting Batches - Test Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  11%|#################                                                                                                                                            | 17/157 [00:00<00:00, 167.17 Batch/s]
Ingesting Batches - Test Dataset:  22%|##################################                                                                                                                           | 34/157 [00:00<00:00, 166.76 Batch/s]
Ingesting Batches - Test Dataset:  32%|###################################################                                                                                                          | 51/157 [00:00<00:00, 166.44 Batch/s]
Ingesting Batches - Test Dataset:  43%|####################################################################                                                                                         | 68/157 [00:00<00:00, 166.05 Batch/s]
Ingesting Batches - Test Dataset:  54%|#####################################################################################                                                                        | 85/157 [00:00<00:00, 165.53 Batch/s]
Ingesting Batches - Test Dataset:  65%|######################################################################################################                                                       | 102/157 [00:00<00:00, 166.10 Batch/s]
Ingesting Batches - Test Dataset:  76%|#######################################################################################################################                                      | 119/157 [00:00<00:00, 163.98 Batch/s]
Ingesting Batches - Test Dataset:  87%|########################################################################################################################################                     | 136/157 [00:00<00:00, 164.07 Batch/s]
Ingesting Batches - Test Dataset:  97%|#########################################################################################################################################################    | 153/157 [00:00<00:00, 164.20 Batch/s]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]

Train Test Prediction Drift

Calculate prediction drift between train dataset and test dataset, using statistical measures.

Conditions Summary
Status Condition More Info
PSI <= 0.15 and Earth Mover's Distance <= 0.075 for prediction drift
Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following prediction properties: ['Samples Per Class'].

Note - data sampling: Running on 10000 train data samples out of 60000. Sample size can be controlled with the "n_samples" parameter.



As we can see, the condition alerts us to the present of drift in the prediction.

Results#

We can see the check successfully detects the (expected) drift in class 0 distribution between the train and test sets. It means the the model correctly predicted 0 for those samples and so we’re seeing drift in the predictions as well as the labels. We note that this check enabled us to detect the presence of label drift (in this case) without needing actual labels for the test data.

But how does this affect the performance of the model?#

Out:

Validating Input:   0%| | 0/1 [00:00<?, ? /s]


Ingesting Batches - Train Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  10%|################                                                                                                                                             | 16/157 [00:00<00:00, 159.95 Batch/s]
Ingesting Batches - Train Dataset:  21%|#################################                                                                                                                            | 33/157 [00:00<00:00, 160.20 Batch/s]
Ingesting Batches - Train Dataset:  32%|##################################################                                                                                                           | 50/157 [00:00<00:00, 160.98 Batch/s]
Ingesting Batches - Train Dataset:  43%|###################################################################                                                                                          | 67/157 [00:00<00:00, 161.11 Batch/s]
Ingesting Batches - Train Dataset:  54%|####################################################################################                                                                         | 84/157 [00:00<00:00, 160.58 Batch/s]
Ingesting Batches - Train Dataset:  64%|#####################################################################################################                                                        | 101/157 [00:00<00:00, 159.44 Batch/s]
Ingesting Batches - Train Dataset:  75%|####################################################################################################################9                                        | 117/157 [00:00<00:00, 158.04 Batch/s]
Ingesting Batches - Train Dataset:  85%|######################################################################################################################################                       | 134/157 [00:00<00:00, 159.18 Batch/s]
Ingesting Batches - Train Dataset:  96%|#######################################################################################################################################################      | 151/157 [00:00<00:00, 160.08 Batch/s]


Ingesting Batches - Test Dataset:   0%|                                                                                                                                                             | 0/157 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  10%|################                                                                                                                                             | 16/157 [00:00<00:00, 155.36 Batch/s]
Ingesting Batches - Test Dataset:  21%|#################################                                                                                                                            | 33/157 [00:00<00:00, 159.54 Batch/s]
Ingesting Batches - Test Dataset:  32%|##################################################                                                                                                           | 50/157 [00:00<00:00, 160.63 Batch/s]
Ingesting Batches - Test Dataset:  43%|###################################################################                                                                                          | 67/157 [00:00<00:00, 160.29 Batch/s]
Ingesting Batches - Test Dataset:  54%|####################################################################################                                                                         | 84/157 [00:00<00:00, 160.24 Batch/s]
Ingesting Batches - Test Dataset:  64%|#####################################################################################################                                                        | 101/157 [00:00<00:00, 160.87 Batch/s]
Ingesting Batches - Test Dataset:  75%|######################################################################################################################                                       | 118/157 [00:00<00:00, 160.92 Batch/s]
Ingesting Batches - Test Dataset:  86%|#######################################################################################################################################                      | 135/157 [00:00<00:00, 161.18 Batch/s]
Ingesting Batches - Test Dataset:  97%|########################################################################################################################################################     | 152/157 [00:00<00:00, 160.38 Batch/s]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]

Class Performance

Summarize given metrics on a dataset and model.

Additional Outputs

Note - data sampling: Running on 10000 train data samples out of 60000. Sample size can be controlled with the "n_samples" parameter.



Inferring the results#

# We can see the drop in the precision of class 0, which was caused by the class
# imbalance indicated earlier by the label drift check.

Run the check on an Object Detection task (COCO)#

from deepchecks.vision.datasets.detection.coco import load_dataset, load_model

train_ds = load_dataset(train=True, object_type='VisionData')
test_ds = load_dataset(train=False, object_type='VisionData')
model = load_model(pretrained=True)

Out:

Downloading: "https://github.com/ultralytics/yolov5/archive/v6.1.zip" to /home/runner/.cache/torch/hub/v6.1.zip
Downloading https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt to yolov5s.pt...

  0%|          | 0.00/14.1M [00:00<?, ?B/s]
 45%|####5     | 6.38M/14.1M [00:00<00:00, 66.8MB/s]
100%|##########| 14.1M/14.1M [00:00<00:00, 109MB/s]

Out:

Validating Input:   0%| | 0/1 [00:00<?, ? /s]
Validating Input: 100%|#| 1/1 [00:11<00:00, 11.67s/ ]


Ingesting Batches - Train Dataset:   0%|  | 0/2 [00:00<?, ? Batch/s]
Ingesting Batches - Train Dataset:  50%|# | 1/2 [00:05<00:05,  5.43s/ Batch]
Ingesting Batches - Train Dataset: 100%|##| 2/2 [00:10<00:00,  5.46s/ Batch]


Ingesting Batches - Test Dataset:   0%|  | 0/2 [00:00<?, ? Batch/s]
Ingesting Batches - Test Dataset:  50%|# | 1/2 [00:05<00:05,  5.46s/ Batch]
Ingesting Batches - Test Dataset: 100%|##| 2/2 [00:11<00:00,  5.54s/ Batch]


Computing Check:   0%| | 0/1 [00:00<?, ? Check/s]
Computing Check: 100%|#| 1/1 [00:00<00:00,  9.24 Check/s]

Train Test Prediction Drift

Calculate prediction drift between train dataset and test dataset, using statistical measures.

Additional Outputs
The Drift score is a measure for the difference between two distributions. In this check, drift is measured for the distribution of the following prediction properties: ['Samples Per Class', 'Bounding Box Area (in pixels)', 'Number of Bounding Boxes Per Image'].


Prediction drift is detected!#

We can see that the COCO128 contains a drift in the out of the box dataset. In addition to the prediction count per class, the prediction drift check for object detection tasks include drift calculation on certain measurements, like the bounding box area and the number of bboxes per image.

Total running time of the script: ( 0 minutes 46.905 seconds)

Gallery generated by Sphinx-Gallery