Heatmap Comparison#
This notebook provides an overview of using and understanding the Heatmap Comparison check.
What Is a Heatmap Comparison?#
Heatmap comparison is a method of detecting data drift in image data. Data drift is simply a change in the distribution of data over time or between several distinct cases. It is also one of the top reasons that a machine learning model's performance degrades over time or when the model is applied to new scenarios.
The Heatmap Comparison check simply computes an average image over all images in each dataset, train and test, and visualizes both average images side by side. That way, we can visually compare the difference between the datasets' brightness distributions. For example, if the training data contains significantly more images with sky, we will see that the average train image is brighter in the upper half of the heatmap.
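The core idea can be sketched in a few lines of NumPy. This is a toy illustration, not the check's actual implementation; the array shapes and random data are made up for the example:

```python
import numpy as np

# Hypothetical stand-ins for two datasets: 100 grayscale 28x28 images each.
rng = np.random.default_rng(0)
train_images = rng.uniform(0, 255, size=(100, 28, 28))
test_images = rng.uniform(0, 255, size=(100, 28, 28))

# An "average image" is simply the per-pixel mean over a dataset.
avg_train = train_images.mean(axis=0)
avg_test = test_images.mean(axis=0)

# The absolute difference highlights regions where brightness drifted
# between the two datasets.
diff = np.abs(avg_train - avg_test)
print(diff.shape)
```

Plotting `avg_train`, `avg_test` and `diff` as heatmaps gives the kind of visualization the check produces.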
Comparing Labels for Object Detection#
For object detection tasks, it is also possible to visualize label drift by displaying the average bounding box label coverage. This is done by producing a label map per image, in which each pixel inside a bounding box is white and the rest is black. Then, the average of all these maps is displayed.
In our previous example, the drift caused by more images with sky in training would also be visible by a lack of labels in the upper half of the average label map of the training data, due to lack of labels in the sky.
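A minimal sketch of how such a label map could be built. The `label_map` helper and the toy boxes below are hypothetical, not deepchecks code; boxes are assumed to be COCO-style `(x, y, w, h)` in pixel coordinates:

```python
import numpy as np

def label_map(image_shape, bboxes):
    """Binary mask: 1 (white) inside any bounding box, 0 (black) elsewhere."""
    mask = np.zeros(image_shape, dtype=float)
    for x, y, w, h in bboxes:
        mask[y:y + h, x:x + w] = 1.0
    return mask

# Two toy 10x10 images with one box each; averaging the per-image masks
# shows where labels concentrate across the dataset.
maps = [
    label_map((10, 10), [(1, 1, 4, 4)]),
    label_map((10, 10), [(2, 2, 4, 4)]),
]
avg_label_map = np.mean(maps, axis=0)
```

In `avg_label_map`, pixels covered by boxes in both images get value 1.0, pixels covered in only one image get 0.5, and uncovered pixels stay 0.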
Other Methods of Drift Detection#
Another, more traditional way to detect such drift is to use statistical methods. That approach is covered by several built-in checks in the deepchecks.vision package, such as the Label Drift check or the Image Dataset Drift check.
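As a rough illustration of the statistical approach, one could bin a scalar image property (such as mean brightness per image) and compare the train and test distributions. The `drift_statistic` helper below is a hypothetical sketch using total-variation distance between binned distributions, not the measure deepchecks itself uses:

```python
import numpy as np

def drift_statistic(train_vals, test_vals, bins=10):
    """Total-variation distance between binned distributions of a property."""
    edges = np.histogram_bin_edges(np.concatenate([train_vals, test_vals]), bins=bins)
    p, _ = np.histogram(train_vals, bins=edges)
    q, _ = np.histogram(test_vals, bins=edges)
    p = p / p.sum()
    q = q / q.sum()
    return 0.5 * float(np.abs(p - q).sum())

# Toy property values: mean brightness per image, shifted between datasets.
rng = np.random.default_rng(1)
train_vals = rng.normal(100, 10, 500)
test_vals = rng.normal(120, 10, 500)
d = drift_statistic(train_vals, test_vals)
```

A value near 0 means the distributions are nearly identical; a value near 1 means they barely overlap.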
Run the Check on a Classification Task (MNIST)#
Imports#
Note
In this example, we use the PyTorch version of the MNIST dataset and model. To run this example using TensorFlow, change the import statement to:
from deepchecks.vision.datasets.classification.mnist_tensorflow import load_dataset
from deepchecks.vision.datasets.classification.mnist_torch import load_dataset
Loading Data#
mnist_data_train = load_dataset(train=True, batch_size=64, object_type='VisionData')
mnist_data_test = load_dataset(train=False, batch_size=64, object_type='VisionData')
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /home/runner/work/deepchecks/deepchecks/deepchecks/vision/datasets/assets/mnist/raw_data/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 61064062.55it/s]
Extracting /home/runner/work/deepchecks/deepchecks/deepchecks/vision/datasets/assets/mnist/raw_data/train-images-idx3-ubyte.gz to /home/runner/work/deepchecks/deepchecks/deepchecks/vision/datasets/assets/mnist/raw_data
(The same download-and-extract sequence repeats for train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz.)
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
from deepchecks.vision.checks import HeatmapComparison
check = HeatmapComparison()
result = check.run(mnist_data_train, mnist_data_test)
result
Processing Train Batches: |█████| 1/1 [Time: 00:01]
Processing Test Batches: |█████| 1/1 [Time: 00:05]
Computing Check: |█████| 1/1 [Time: 00:00]
To display the results in an IDE like PyCharm, you can use the following code:
# result.show_in_window()
The result will be displayed in a new window.
Run the Check on an Object Detection Task (Coco)#
Note
In this example, we use the PyTorch version of the COCO dataset and model. To run this example using TensorFlow, change the import statement to:
from deepchecks.vision.datasets.detection.coco_tensorflow import load_dataset
from deepchecks.vision.datasets.detection.coco_torch import load_dataset
train_ds = load_dataset(train=True, object_type='VisionData')
test_ds = load_dataset(train=False, object_type='VisionData')
check = HeatmapComparison()
result = check.run(train_ds, test_ds)
result
Processing Train Batches: |█████| 1/1 [Time: 00:00]
Processing Test Batches: |█████| 1/1 [Time: 00:00]
Computing Check: |█████| 1/1 [Time: 00:00]
Limit to Specific Classes#
The check can be limited to comparing the bounding box coverage for a specific set of classes. We'll use that to inspect only objects labeled as a person (class_id 0).
check = HeatmapComparison(classes_to_display=['person'])
result = check.run(train_ds, test_ds)
result
Processing Train Batches: |█████| 1/1 [Time: 00:00]
Processing Test Batches: |█████| 1/1 [Time: 00:00]
Computing Check: |█████| 1/1 [Time: 00:00]
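Conceptually, such filtering amounts to keeping only the bounding boxes of the requested classes before building the coverage maps. A toy sketch, assuming a hypothetical COCO-style label layout of `(class_id, x, y, w, h)` rows, not the check's actual internals:

```python
import numpy as np

# Hypothetical labels for one image: each row is (class_id, x, y, w, h).
labels = np.array([
    [0, 1, 1, 3, 3],   # class_id 0: person
    [39, 5, 5, 2, 2],  # some other class
])

# Keep only boxes of the class we want to display (person, class_id 0),
# then rasterize them into a coverage mask as before.
person_boxes = labels[labels[:, 0] == 0][:, 1:].astype(int)

mask = np.zeros((10, 10))
for x, y, w, h in person_boxes:
    mask[y:y + h, x:x + w] = 1.0
```

Averaging such per-image masks over the dataset yields the class-restricted heatmap shown above.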
We can see a significantly increased abundance of people in the test data, concentrated in the lower center of the images!
Total running time of the script: (0 minutes 13.295 seconds)