Train Test Label Drift
This notebook provides an overview of how to use and interpret the vision label drift check.
What is a label drift?
The term drift (and all its derivatives) is used to describe any change in the data compared to the data the model was trained on. Specifically, label drift indicates changes in the distribution of the label we are trying to predict.
Causes of label drift include:
Natural drift in the data, such as a certain class becoming more prevalent in the test set. For example, cronuts becoming more popular in a food classification dataset.
Labeling issues, such as an analyst drawing incorrect bounding boxes for an object detection task.
How Does the TrainTestLabelDrift Check Work?
There are many methods to detect drift. They usually rely on statistical measures of the difference between two distributions. We experimented with various approaches and found that, for detecting drift between two one-dimensional distributions, the following two methods give the best results:
For numerical features, the Wasserstein Distance (Earth Mover’s Distance)
For categorical features, the Population Stability Index (PSI)
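To make the categorical metric concrete, the sketch below computes the PSI between two per-class count vectors with NumPy, using PSI = sum over classes of (p_i - q_i) * ln(p_i / q_i) on the normalized proportions. It is an illustrative sketch, not deepchecks’ implementation. The function name `psi` and the `eps` clipping for classes missing on one side are assumptions.

```python
import numpy as np

def psi(expected_counts, actual_counts, eps=1e-6):
    """Population Stability Index between two categorical distributions.

    Both arguments are per-class counts aligned by index. Clipping to `eps`
    avoids log(0) for classes absent on one side (an assumption of this sketch).
    """
    p = np.asarray(expected_counts, dtype=float)
    q = np.asarray(actual_counts, dtype=float)
    p = np.clip(p / p.sum(), eps, None)  # normalize counts to proportions
    q = np.clip(q / q.sum(), eps, None)
    return float(np.sum((p - q) * np.log(p / q)))

train_counts = [100, 100, 100]
print(psi(train_counts, [50, 50, 50]))    # 0.0: same proportions, no drift
print(psi(train_counts, [10, 100, 190]))  # positive: class proportions shifted
```

Identical proportions give a PSI of 0, and the score grows as the class proportions diverge.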
However, drift often cannot be measured on the labels directly, because they may be complex structures (for example, a variable-length list of bounding boxes per image). Instead, these methods are applied to label properties, as described in the next section.
Using Label Properties to Detect Label Drift
In computer vision specifically, our labels may be complex, and measuring their drift is not a straightforward task. Therefore, we calculate drift on properties of the labels. Each property is a simple one-dimensional quantity whose drift can be measured directly.
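For a numerical property such as bounding box area, one way to measure drift between train and test is the 1-D Wasserstein distance, which equals the area between the two empirical CDFs. The NumPy sketch below uses a hypothetical helper (`wasserstein_1d` is not a deepchecks function) that returns the raw, unnormalized distance. Deepchecks may scale its score differently. If SciPy is available, `scipy.stats.wasserstein_distance` computes the same unnormalized quantity.

```python
import numpy as np

def wasserstein_1d(u_values, v_values):
    """Unnormalized 1-D Wasserstein (Earth Mover's) distance between two samples."""
    u = np.sort(np.asarray(u_values, dtype=float))
    v = np.sort(np.asarray(v_values, dtype=float))
    # Every point where either empirical CDF can change value.
    all_values = np.sort(np.concatenate([u, v]))
    widths = np.diff(all_values)
    # Each empirical CDF, evaluated at the left end of every interval.
    u_cdf = np.searchsorted(u, all_values[:-1], side="right") / u.size
    v_cdf = np.searchsorted(v, all_values[:-1], side="right") / v.size
    # Area between the two CDFs.
    return float(np.sum(np.abs(u_cdf - v_cdf) * widths))

train_areas = [1.0, 2.0, 3.0, 4.0]
test_areas = [11.0, 12.0, 13.0, 14.0]  # every box is 10 units larger
print(wasserstein_1d(train_areas, test_areas))  # 10.0
```

Shifting every value by a constant moves the distance by exactly that constant, which makes the score easy to read in the units of the property.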
Which Label Properties Are Used?
Task Type | Property name | Description |
---|---|---|
Classification | Samples Per Class | Number of images per class |
Object Detection | Samples Per Class | Number of bounding boxes per class |
Object Detection | Bounding Box Area | Area of bounding box (height * width) |
Object Detection | Number of Bounding Boxes Per Image | Number of bounding box objects in each image |
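The sketch below shows how the three object detection properties in the table could be computed from raw labels. It assumes each image’s label is an array of shape (N, 5) with rows [class_id, x, y, width, height]. That layout and the helper name `label_properties` are assumptions for illustration, not deepchecks’ internal code. The resulting class counts are categorical, while the areas and per-image counts are numerical.

```python
from collections import Counter

import numpy as np

def label_properties(labels_per_image):
    """Compute the table's object detection label properties.

    labels_per_image: one entry per image, each array-like of shape (N, 5)
    with rows [class_id, x, y, width, height] (assumed layout).
    """
    samples_per_class = Counter()  # number of bounding boxes per class
    bbox_areas = []                # one area (width * height) per box
    bboxes_per_image = []          # one box count per image
    for labels in labels_per_image:
        boxes = np.asarray(labels, dtype=float).reshape(-1, 5)  # empty image -> shape (0, 5)
        samples_per_class.update(int(c) for c in boxes[:, 0])
        bbox_areas.extend((boxes[:, 3] * boxes[:, 4]).tolist())
        bboxes_per_image.append(len(boxes))
    return samples_per_class, bbox_areas, bboxes_per_image

example_labels = [
    [[0, 10, 10, 4, 5], [2, 0, 0, 3, 3]],  # image with two boxes
    [],                                    # image with no boxes
    [[0, 5, 5, 2, 2]],                     # image with one box
]
per_class, areas, counts = label_properties(example_labels)
print(dict(per_class))  # {0: 2, 2: 1}
print(areas)            # [20.0, 9.0, 4.0]
print(counts)           # [2, 0, 1]
```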
Run the check on a Classification task (MNIST)
Imports
from deepchecks.vision.checks import TrainTestLabelDrift
from deepchecks.vision.datasets.classification.mnist import load_dataset
Loading Data
train_ds = load_dataset(train=True, batch_size=64, object_type='VisionData')
test_ds = load_dataset(train=False, batch_size=1000, object_type='VisionData')
Running TrainTestLabelDrift on classification
check = TrainTestLabelDrift()
check.run(train_ds, test_ds)
Understanding the results
We can see there is almost no drift between the train and test labels. This indicates a good train/test split, one that is balanced and random. Next, let’s check the performance of a simple model trained on MNIST.
from deepchecks.vision.checks import ClassPerformance
from deepchecks.vision.datasets.classification.mnist import \
load_model as load_mnist_model
mnist_model = load_mnist_model(pretrained=True)
ClassPerformance().run(train_ds, test_ds, mnist_model)
MNIST with label drift
Now, let’s split the MNIST dataset in a different way that results in label drift, and see how this affects performance. We first merge the original train and test sets and randomly re-split them. We then create a custom collate_fn for the test dataloader. Each class 0 sample keeps its label with a 1 in 5 chance and is otherwise relabeled as class 1.
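Before running the check, a quick simulation shows what this rule does to the label distribution. The sketch uses hypothetical, uniformly distributed digit labels rather than the real MNIST split. It also uses NumPy’s `Generator` API instead of the legacy `np.random.randint` that the tutorial’s collate function uses. The expected effect is that the class 0 share drops to about a fifth of its original value, class 1 absorbs the difference, and all other classes stay unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the test labels: 10,000 digits drawn uniformly from 0-9.
labels = rng.integers(0, 10, size=10_000)

def relabel_zeros(labels, rng):
    """Keep each class 0 label with probability 1/5; otherwise relabel it as 1."""
    drifted = labels.copy()
    zero_idx = np.flatnonzero(drifted == 0)
    # rng.integers(0, 5) != 0 is True with probability 4/5, like np.random.randint(5) != 0.
    relabel = rng.integers(0, 5, size=zero_idx.size) != 0
    drifted[zero_idx[relabel]] = 1
    return drifted

drifted = relabel_zeros(labels, rng)
before = np.bincount(labels, minlength=10) / labels.size
after = np.bincount(drifted, minlength=10) / drifted.size
print(f"class 0 share: {before[0]:.3f} -> {after[0]:.3f}")
print(f"class 1 share: {before[1]:.3f} -> {after[1]:.3f}")
```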
import torch
mnist_dataloader_train = load_dataset(train=True, batch_size=64, object_type='DataLoader')
mnist_dataloader_test = load_dataset(train=False, batch_size=1000, object_type='DataLoader')
full_mnist = torch.utils.data.ConcatDataset([mnist_dataloader_train.dataset, mnist_dataloader_test.dataset])
train_dataset, test_dataset = torch.utils.data.random_split(full_mnist, [60000,10000], generator=torch.Generator().manual_seed(42))
Inserting drift into the test set
import numpy as np
from torch.utils.data._utils.collate import default_collate
np.random.seed(42)
def collate_test(batch):
    """Keep each class 0 label with probability 1/5; relabel the rest as class 1."""
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 0:
            if np.random.randint(5) == 0:
                # 1 in 5 chance: keep the original (image, 0) pair
                modified_batch.append(item)
            else:
                # otherwise: keep the image of a 0 but label it as 1
                modified_batch.append((image, 1))
        else:
            modified_batch.append(item)
    return default_collate(modified_batch)
mod_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
mod_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=collate_test)
from deepchecks.vision.datasets.classification.mnist import MNISTData
mod_train_ds = MNISTData(mod_train_loader)
mod_test_ds = MNISTData(mod_test_loader)
Run the check
check = TrainTestLabelDrift()
check.run(mod_train_ds, mod_test_ds)
Add a condition
We could also add a condition to the check to alert us to changes in the label distribution, such as the one that occurred here.
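Conceptually, a condition like this compares each property’s drift score with a maximum allowed value and fails if any score exceeds that limit. The sketch below illustrates that logic only. The `DriftCondition` class, its method-name strings, and the threshold values are hypothetical. They are not deepchecks’ API or defaults, so consult the TrainTestLabelDrift API reference for the real parameters.

```python
from dataclasses import dataclass

@dataclass
class DriftCondition:
    """Hypothetical 'drift score not greater than' condition (illustrative thresholds)."""
    max_psi: float = 0.2  # limit for categorical properties scored with PSI
    max_emd: float = 0.1  # limit for numerical properties scored with Earth Mover's Distance

    def failing_properties(self, scores):
        """scores: {property name: (method, score)}; returns properties over their limit."""
        failed = {}
        for prop, (method, score) in scores.items():
            if method == "PSI":
                limit = self.max_psi
            elif method == "EMD":
                limit = self.max_emd
            else:
                raise ValueError(f"unknown drift method: {method}")
            if score > limit:  # "not greater than": a score equal to the limit passes
                failed[prop] = score
        return failed

condition = DriftCondition()
print(condition.failing_properties({"Samples Per Class": ("PSI", 0.35)}))  # {'Samples Per Class': 0.35}
```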
check = TrainTestLabelDrift().add_condition_drift_score_not_greater_than()
check.run(mod_train_ds, mod_test_ds)
# As we can see, the condition alerts us to the presence of drift in the label.
Results
We can see that the check successfully detects the (expected) drift in the class 0 distribution between the train and test sets.
But how does this affect the performance of the model?
ClassPerformance().run(mod_train_ds, mod_test_ds, mnist_model)
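Inferring the results
We can see a drop in the precision of class 0. The drop is caused by the label shift that the label drift check flagged earlier. Most test images of the digit 0 are now labeled 1, so the model’s correct predictions of 0 count as false positives for class 0.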
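A small worked example shows why relabeling zeros as ones hurts class 0 precision even for a model that recognizes every digit correctly. The data are hypothetical: 100 images of each digit, with 80 of the zeros relabeled as 1 to mirror the 4 in 5 rule. `precision` is a small helper written for this sketch.

```python
import numpy as np

# Hypothetical data: 100 images showing a 0 followed by 100 images showing a 1.
true_digit = np.array([0] * 100 + [1] * 100)
predicted = true_digit.copy()  # a model that always predicts the digit it sees
observed_label = true_digit.copy()
observed_label[:80] = 1        # 80 of the 100 zeros are relabeled as 1

def precision(y_true, y_pred, cls):
    """Fraction of predictions of `cls` whose label is actually `cls`."""
    predicted_cls = y_pred == cls
    return float(np.sum(y_true[predicted_cls] == cls) / np.sum(predicted_cls))

print(precision(true_digit, predicted, 0))      # 1.0 against the original labels
print(precision(observed_label, predicted, 0))  # 0.2 against the drifted labels
```

The same setup also lowers recall for class 1. Of the 180 images labeled 1, the 80 relabeled zeros are predicted as 0, which leaves a recall of 100/180.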
Run the check on an Object Detection task (COCO)
from deepchecks.vision.datasets.detection.coco import load_dataset
train_ds = load_dataset(train=True, object_type='VisionData')
test_ds = load_dataset(train=False, object_type='VisionData')
check = TrainTestLabelDrift()
check.run(train_ds, test_ds)
Label drift is detected!
We can see that the out-of-the-box COCO128 dataset contains label drift between its train and test sets. In addition to the label count per class, the label drift check for object detection tasks calculates drift on other label properties, such as the bounding box area and the number of bounding boxes per image.