.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "user-guide/vision/auto_quickstarts/plot_detection_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_user-guide_vision_auto_quickstarts_plot_detection_tutorial.py:

.. _vision_detection_tutorial:

==========================
Object Detection Tutorial
==========================

In this tutorial, you will learn how to validate your **object detection model** using deepchecks test suites.
You can read more about the different checks and suites for computer vision use cases at the
:doc:`examples section `.

If you just want to see the output of this tutorial, jump to the :ref:`observing_the_result` section.

An object detection task usually consists of two parts:

- Object Localization, where the model predicts the location of an object in the image,
- Object Classification, where the model predicts the class of the detected object.

The common output of an object detection model is a list of bounding boxes around the objects, and their classes.

.. GENERATED FROM PYTHON SOURCE LINES 24-26

Defining the data and model
===========================

.. GENERATED FROM PYTHON SOURCE LINES 26-48

.. code-block:: default

    # Importing the required packages
    import os
    import urllib.request
    import xml.etree.ElementTree as ET
    import zipfile
    from functools import partial

    import albumentations as A
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torchvision
    from albumentations.pytorch import ToTensorV2
    from PIL import Image
    from torch import nn
    from torch.utils.data import DataLoader, Dataset
    from torchvision.models.detection import _utils as det_utils
    from torchvision.models.detection.ssdlite import SSDLiteClassificationHead

    from deepchecks.vision.detection_data import DetectionData

..
   GENERATED FROM PYTHON SOURCE LINES 49-60

Load Data
~~~~~~~~~
The model in this tutorial is used to detect tomatoes in images. The model is trained on a dataset consisting of
895 images of tomatoes, with bounding box annotations provided in PASCAL VOC format.
All annotations belong to a single class: tomato.

.. note::
    The dataset is available at the following link:
    https://www.kaggle.com/andrewmvd/tomato-detection

    We thank the authors of the dataset for providing the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 60-133

.. code-block:: default

    url = 'https://figshare.com/ndownloader/files/34488599'
    urllib.request.urlretrieve(url, 'tomato-detection.zip')

    with zipfile.ZipFile('tomato-detection.zip', 'r') as zip_ref:
        zip_ref.extractall('.')

    class TomatoDataset(Dataset):
        def __init__(self, root, transforms):
            self.root = root
            self.transforms = transforms

            self.images = list(sorted(os.listdir(os.path.join(root, 'images'))))
            self.annotations = list(sorted(os.listdir(os.path.join(root, 'annotations'))))

        def __getitem__(self, idx):
            img_path = os.path.join(self.root, "images", self.images[idx])
            ann_path = os.path.join(self.root, "annotations", self.annotations[idx])
            img = Image.open(img_path).convert("RGB")
            bboxes = []
            labels = []
            # Parse the PASCAL VOC annotation file of this image
            with open(ann_path, 'r') as f:
                tree = ET.parse(f)
                root = tree.getroot()
                size = root.find('size')
                w = int(size.find('width').text)
                h = int(size.find('height').text)

                for obj in root.iter('object'):
                    difficult = obj.find('difficult').text
                    if int(difficult) == 1:
                        continue
                    # The dataset has a single class (tomato), with class id 1
                    cls_id = 1
                    xmlbox = obj.find('bndbox')
                    b = [float(xmlbox.find('xmin').text), float(xmlbox.find('ymin').text),
                         float(xmlbox.find('xmax').text), float(xmlbox.find('ymax').text)]
                    bboxes.append(b)
                    labels.append(cls_id)

            bboxes = torch.as_tensor(np.array(bboxes), dtype=torch.float32)
            labels = torch.as_tensor(np.array(labels), dtype=torch.int64)

            if self.transforms is not None:
                res = self.transforms(image=np.array(img), bboxes=bboxes, class_labels=labels)

            target = {
                'boxes': [torch.Tensor(x) for x in res['bboxes']],
                'labels': res['class_labels']
            }

            img = res['image']

            return img, target

        def __len__(self):
            return len(self.images)

    data_transforms = A.Compose([
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

    dataset = TomatoDataset(root=os.path.join(os.path.curdir, 'tomato-detection/data'),
                            transforms=data_transforms)
    # Split the data 90% / 10% into train and validation sets, with a fixed seed
    train_set, val_set = torch.utils.data.random_split(dataset,
                                                       [int(len(dataset)*0.9), len(dataset)-int(len(dataset)*0.9)],
                                                       generator=torch.Generator().manual_seed(42))
    val_set.transforms = A.Compose([ToTensorV2()])
    # Samples hold a variable number of boxes, so batches are collated into tuples
    train_loader = DataLoader(train_set, batch_size=64, collate_fn=(lambda batch: tuple(zip(*batch))))
    val_loader = DataLoader(val_set, batch_size=64, collate_fn=(lambda batch: tuple(zip(*batch))))

.. GENERATED FROM PYTHON SOURCE LINES 134-137

Visualize a Few Images
~~~~~~~~~~~~~~~~~~~~~~
Let's visualize a few training images to understand the data augmentation.

.. GENERATED FROM PYTHON SOURCE LINES 137-172

.. code-block:: default

    def prepare(inp):
        """Imshow for Tensor."""
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        # Undo the normalization and rescale to [0, 255]
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1) * 255
        inp = inp.transpose((2,0,1))
        return torch.tensor(inp, dtype=torch.uint8)

    import torchvision.transforms.functional as F

    def show(imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(20,20))
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    from torchvision.utils import draw_bounding_boxes

    data = next(iter(train_loader))
    inp, targets = data[0][:4], data[1][:4]

    result = [draw_bounding_boxes(prepare(inp[i]), torch.stack(targets[i]['boxes']),
                                  colors=['yellow'] * torch.stack(targets[i]['boxes']).shape[0], width=5)
              for i in range(len(targets))]
    show(result)

.. image-sg:: /user-guide/vision/auto_quickstarts/images/sphx_glr_plot_detection_tutorial_001.png
   :alt: plot detection tutorial
   :srcset: /user-guide/vision/auto_quickstarts/images/sphx_glr_plot_detection_tutorial_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 173-184

.. image:: /_static/images/tutorials/tomatoes.png
    :alt: Tomatoes with bbox

Downloading a Pre-trained Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In this tutorial, we will download a pre-trained SSDlite model with a MobileNetV3 Large backbone
from the official PyTorch repository. For more details, please refer to the
`official documentation `_.

After downloading the model, we will fine-tune it for our particular classes by replacing the
pre-trained head with a new one that matches our needs.

.. GENERATED FROM PYTHON SOURCE LINES 184-196

.. code-block:: default

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)

    in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
    num_anchors = model.anchor_generator.num_anchors_per_location()
    norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    # Replace the classification head: 2 classes (background and tomato)
    model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, 2, norm_layer)
    model.to(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" to /home/runner/.cache/torch/hub/checkpoints/ssdlite320_mobilenet_v3_large_coco-a79551df.pth
    0%| | 0.00/13.4M [00:00
    First element is: with len of 64
    Example output of an image shape from the dataloader torch.Size([3, 224, 224])
    Image values tensor([[[-1.79253, -1.82678, -1.82678, ..., 1.39267, 1.34130, 1.32417],
             [-1.72403, -1.79253, -1.80966, ..., 1.35842, 1.32417, 1.34130],
             [-1.75828, -1.74116, -1.70691, ..., 1.32417, 1.34130, 1.35842],
             ...,
             [-1.84391, -1.82678, -1.75828, ..., 0.62206, 0.19394, -0.35405],
             [-1.80966, -1.79253, -1.72403, ..., 0.81043, 0.72481, 0.34806],
             [-1.79253, -1.84391, -1.75828, ..., 0.81043, 0.82755, 0.69056]],

            [[-1.38796, -1.45798, -1.45798, ..., 1.51821, 1.46569, 1.46569],
             [-1.38796, -1.47549, -1.52801, ..., 1.50070, 1.46569, 1.48319],
             [-1.42297, -1.47549, -1.49300, ..., 1.46569, 1.50070, 1.50070],
             ...,
             [-1.70308, -1.68557, -1.61555, ..., 0.67787, 0.22269, -0.33753],
             [-1.68557, -1.66807, -1.58053, ..., 0.87045, 0.74790, 0.38025],
             [-1.68557, -1.70308, -1.61555, ..., 0.87045, 0.85294, 0.71289]],

            [[-1.57786, -1.61272, -1.61272, ..., 1.66397, 1.61168, 1.59425],
             [-1.54301, -1.59529, -1.64758, ..., 1.62911, 1.59425, 1.59425],
             [-1.59529, -1.59529, -1.61272, ..., 1.59425, 1.61168, 1.62911],
             ...,
             [-1.63015, -1.59529, -1.52558, ..., 0.46135, 0.02562, -0.53211],
             [-1.59529, -1.54301, -1.47329, ..., 0.72279, 0.61821, 0.21734],
             [-1.59529, -1.59529, -1.50815, ..., 0.72279, 0.72279, 0.56593]]])
    --------------------------------------------------------------------------------
    Second element is: with len of 64
    Example output of a label shape from the dataloader {'boxes': [tensor([ 0.00000, 75.13600, 39.68000, 165.75999]), tensor([ 0.00000, 0.00000, 94.08000, 93.56800])], 'labels': [tensor(1), tensor(1)]}

.. GENERATED FROM PYTHON SOURCE LINES 226-239

Implementing the DetectionData class
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The checks in the package validate the model & data by calculating various quantities over the data, labels and
predictions. To do that, these must be in a pre-defined format that depends on the task type.
The first step is to implement a class that enables deepchecks to interact with your model and data and transform
them into this pre-defined format. In this tutorial, we will implement the object detection task type by
implementing a class that inherits from the :class:`deepchecks.vision.detection_data.DetectionData` class.

The DetectionData class contains additional data and general methods intended for easy access to relevant metadata
for validating object detection ML models.
To learn more about the expected format, please visit the API reference for the
:class:`deepchecks.vision.detection_data.DetectionData` class.

.. GENERATED FROM PYTHON SOURCE LINES 239-312

.. code-block:: default

    from deepchecks.vision.detection_data import DetectionData


    class TomatoData(DetectionData):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def batch_to_images(self, batch):
            """
            Convert a batch of data to images in the expected format. The expected format is an iterable of cv2 images,
            where each image is a numpy array of shape (height, width, channels). The numbers in the array should be in the
            range [0, 255] in a uint8 format.
            """
            inp = torch.stack(list(batch[0])).cpu().detach().numpy().transpose((0, 2, 3, 1))
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
            # Un-normalize the images
            inp = std * inp + mean
            inp = np.clip(inp, 0, 1)
            return inp * 255

        def batch_to_labels(self, batch):
            """
            Convert a batch of data to labels in the expected format. The expected format is a list of tensors of length N,
            where N is the number of samples. Each tensor has shape [B, 5], where B is the number of bboxes in the image,
            and each bounding box is in the structure of [class_id, x, y, w, h].
            """
            tensor_annotations = batch[1]
            label = []
            for annotation in tensor_annotations:
                if len(annotation["boxes"]):
                    bbox = torch.stack(annotation["boxes"])
                    # Convert the Pascal VOC xyxy format to xywh format
                    bbox[:, 2:] = bbox[:, 2:] - bbox[:, :2]
                    # The label shape is [class_id, x, y, w, h]
                    label.append(
                        torch.concat([torch.stack(annotation["labels"]).reshape((-1, 1)), bbox], dim=1)
                    )
                else:
                    # If it's an empty image, we need to add an empty label
                    label.append(torch.tensor([]))
            return label

        def infer_on_batch(self, batch, model, device):
            """
            Returns the predictions for a batch of data. The expected format is a list of tensors of length N, where N
            is the number of samples. Each tensor has shape [B, 6], where B is the number of bboxes in the predictions,
            and each bounding box is in the structure of [x, y, w, h, score, class_id].
            """
            nm_thrs = 0.2
            score_thrs = 0.7
            imgs = list(img.to(device) for img in batch[0])
            # Getting the predictions of the model on the batch
            with torch.no_grad():
                preds = model(imgs)
            processed_pred = []
            for pred in preds:
                # Performing non-maximum suppression on the detections
                keep_boxes = torchvision.ops.nms(pred['boxes'], pred['scores'], nm_thrs)
                score_filter = pred['scores'][keep_boxes] > score_thrs

                # get the filtered result
                test_boxes = pred['boxes'][keep_boxes][score_filter].reshape((-1, 4))
                test_boxes[:, 2:] = test_boxes[:, 2:] - test_boxes[:, :2]  # xyxy to xywh
                test_labels = pred['labels'][keep_boxes][score_filter]
                test_scores = pred['scores'][keep_boxes][score_filter]

                processed_pred.append(
                    torch.concat([test_boxes, test_scores.reshape((-1, 1)), test_labels.reshape((-1, 1))], dim=1))
            return processed_pred

.. GENERATED FROM PYTHON SOURCE LINES 313-314

After defining the task class, we can validate it by running the following code:

.. GENERATED FROM PYTHON SOURCE LINES 314-328

.. code-block:: default

    # We have a single label here, which is the tomato class
    # The label_map is a dictionary that maps the class id to the class name, for display purposes.
    LABEL_MAP = {
        1: 'Tomato'
    }
    training_data = TomatoData(data_loader=train_loader, label_map=LABEL_MAP)
    val_data = TomatoData(data_loader=val_loader, label_map=LABEL_MAP)

    training_data.validate_format(model, device=device)
    val_data.validate_format(model, device=device)

    # And observe the output:

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Deepchecks will try to validate the extractors given...
    torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1], [1]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/quickstarts/deepchecks_formatted_image (2).jpg
    Visual examples of an image with prediction and label data. Label is red, prediction is blue, and deepchecks loves you.
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************
    Deepchecks will try to validate the extractors given...
    Structure validation
    --------------------
    Label formatter: Pass!
    Prediction formatter: Pass!
    Image formatter: Pass!

    Content validation
    ------------------
    For validating the content within the structure you have to manually observe the classes, image, label and prediction.
    Examples of classes observed in the batch's labels: [[1, 1, 1, 1], [1], [1, 1, 1, 1, 1, 1, 1], [1], [1]]
    Visual images & label & prediction: should open in a new window
    *******************************************************************************
    This machine does not support GUI
    The formatted image was saved in:
    /home/runner/work/deepchecks/deepchecks/docs/source/user-guide/vision/quickstarts/deepchecks_formatted_image (3).jpg
    Visual examples of an image with prediction and label data. Label is red, prediction is blue, and deepchecks loves you.
    validate_extractors can be set to skip the image saving or change the save path
    *******************************************************************************

..
   GENERATED FROM PYTHON SOURCE LINES 329-333

Running Deepchecks' full suite on our data and model!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Now that we have defined the task class, we can validate the model with the full suite of deepchecks.
This takes only a few lines of code:

.. GENERATED FROM PYTHON SOURCE LINES 333-339

.. code-block:: default

    from deepchecks.vision.suites import full_suite

    suite = full_suite()
    result = suite.run(training_data, val_data, model, device=device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Validating Input: | | 0/1 [Time: 00:00]
    Validating Input: |#####| 1/1 [Time: 00:06]
    Validating Input: |#####| 1/1 [Time: 00:06]Properties that have class_id as output_type will be skipped.
    Ingesting Batches - Train Dataset: | | 0/13 [Time: 00:00]
    Ingesting Batches - Train Dataset: |# | 1/13 [Time: 00:04]
    Ingesting Batches - Train Dataset: |## | 2/13 [Time: 00:08]
    Ingesting Batches - Train Dataset: |### | 3/13 [Time: 00:13]
    Ingesting Batches - Train Dataset: |#### | 4/13 [Time: 00:17]
    Ingesting Batches - Train Dataset: |##### | 5/13 [Time: 00:22]
    Ingesting Batches - Train Dataset: |###### | 6/13 [Time: 00:26]
    Ingesting Batches - Train Dataset: |####### | 7/13 [Time: 00:31]
    Ingesting Batches - Train Dataset: |######## | 8/13 [Time: 00:35]
    Ingesting Batches - Train Dataset: |######### | 9/13 [Time: 00:40]
    Ingesting Batches - Train Dataset: |########## | 10/13 [Time: 00:44]
    Ingesting Batches - Train Dataset: |########### | 11/13 [Time: 00:49]
    Ingesting Batches - Train Dataset: |############ | 12/13 [Time: 00:53]
    Ingesting Batches - Train Dataset: |#############| 13/13 [Time: 00:56]
    Ingesting Batches - Train Dataset: |#############| 13/13 [Time: 00:56]
    Computing Single Dataset Checks - Train Dataset: | | 0/6 [Time: 00:00]
    Computing Single Dataset Checks - Train Dataset: |# | 1/6 [Time: 00:00, Check=Mean Average Precision Report]
    Computing Single Dataset Checks - Train Dataset: |## | 2/6 [Time: 00:00, Check=Mean Average Recall Report]
    Computing Single Dataset Checks - Train Dataset: |#### | 4/6 [Time: 00:07, Check=Image Segment Performance]
    Computing Single Dataset Checks - Train Dataset: |##### | 5/6 [Time: 00:07, Check=Image Property Outliers]
    Computing Single Dataset Checks - Train Dataset: |######| 6/6 [Time: 00:07, Check=Label Property Outliers]
    Computing Single Dataset Checks - Train Dataset: |######| 6/6 [Time: 00:07, Check=Label Property Outliers]
    Ingesting Batches - Test Dataset: | | 0/2 [Time: 00:00]
    Ingesting Batches - Test Dataset: |##5 | 1/2 [Time: 00:04]
    Ingesting Batches - Test Dataset: |#####| 2/2 [Time: 00:05]
    Ingesting Batches - Test Dataset: |#####| 2/2 [Time: 00:05]
    Computing Single Dataset Checks - Test Dataset: | | 0/6 [Time: 00:00]
    Computing Single Dataset Checks - Test Dataset: |## | 2/6 [Time: 00:00, Check=Mean Average Recall Report]
    Computing Single Dataset Checks - Test Dataset: |#### | 4/6 [Time: 00:00, Check=Image Segment Performance]
    Computing Single Dataset Checks - Test Dataset: |######| 6/6 [Time: 00:01, Check=Label Property Outliers]
    Computing Single Dataset Checks - Test Dataset: |######| 6/6 [Time: 00:01, Check=Label Property Outliers]
    Computing Checks: | | 0/10 [Time: 00:00]
    Computing Checks: | | 0/10 [Time: 00:00, Check=Class Performance]
    Computing Checks: |# | 1/10 [Time: 00:00, Check=Class Performance]
    Computing Checks: |# | 1/10 [Time: 00:00, Check=Train Test Prediction Drift]
    Computing Checks: |## | 2/10 [Time: 00:00, Check=Train Test Prediction Drift]
    Computing Checks: |## | 2/10 [Time: 00:00, Check=Simple Model Comparison]
    Computing Checks: |## | 2/10 [Time: 00:00, Check=Model Error Analysis]
    Default parameter min_samples_leaf will change in version 2.6.See https://github.com/scikit-learn-contrib/category_encoders/issues/327
    Default parameter smoothing will change in version 2.6.See https://github.com/scikit-learn-contrib/category_encoders/issues/327
    Computing Checks: |#### | 4/10 [Time: 00:00, Check=Model Error Analysis]
    Computing Checks: |#### | 4/10 [Time: 00:00, Check=New Labels]
    Computing Checks: |#### | 4/10 [Time: 00:00, Check=Heatmap Comparison]
    Computing Checks: |#### | 4/10 [Time: 00:00, Check=Train Test Label Drift]
    Computing Checks: |####### | 7/10 [Time: 00:00, Check=Train Test Label Drift]
    Computing Checks: |####### | 7/10 [Time: 00:00, Check=Image Property Drift]
    Computing Checks: |####### | 7/10 [Time: 00:00, Check=Image Dataset Drift]
    Computing Checks: |######### | 9/10 [Time: 00:01, Check=Image Dataset Drift]
    Computing Checks: |######### | 9/10 [Time: 00:01, Check=Property Label Correlation Change]
    Computing Checks: |##########| 10/10 [Time: 00:01, Check=Property Label Correlation Change]

.. GENERATED FROM PYTHON SOURCE LINES 340-345

.. _observing_the_result:

Observing the results:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The results can be saved as an HTML file with the following code:

.. GENERATED FROM PYTHON SOURCE LINES 345-348

.. code-block:: default

    result.save_as_html('output.html')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    'output (2).html'

.. GENERATED FROM PYTHON SOURCE LINES 349-350

Or, if working inside a notebook, the output can be displayed directly by making the result object the last
expression of a cell:

.. GENERATED FROM PYTHON SOURCE LINES 350-352

.. code-block:: default

    result

.. raw:: html

    Full Suite

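The label format expected from ``batch_to_labels`` is one ``[class_id, x, y, w, h]`` row per bounding box,
while the dataset stores Pascal VOC corners (``xmin, ymin, xmax, ymax``). The conversion can be sketched in
plain Python, without torch, as follows. The helper names ``xyxy_to_xywh`` and ``to_label_rows`` are
hypothetical and are not part of deepchecks; the box values are taken from the example label printed earlier
in this tutorial.

.. code-block:: python

    def xyxy_to_xywh(box):
        """Convert [xmin, ymin, xmax, ymax] to [x, y, w, h] (top-left corner plus size)."""
        xmin, ymin, xmax, ymax = box
        return [xmin, ymin, xmax - xmin, ymax - ymin]


    def to_label_rows(boxes, class_ids):
        """Build one [class_id, x, y, w, h] row per bounding box (hypothetical helper)."""
        return [[class_id] + xyxy_to_xywh(box) for box, class_id in zip(boxes, class_ids)]


    # Two boxes in Pascal VOC corner format, both of class 1 (tomato)
    boxes = [[0.0, 75.136, 39.68, 165.76], [0.0, 0.0, 94.08, 93.568]]
    rows = to_label_rows(boxes, [1, 1])
    for row in rows:
        print(row)

The top-left corner is unchanged and only the last two values are replaced by the width and height, which
matches the in-place ``bbox[:, 2:] - bbox[:, :2]`` step in ``batch_to_labels``.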

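The post-processing in ``infer_on_batch`` (non-maximum suppression, then a score threshold, then conversion to
``[x, y, w, h, score, class_id]`` rows) can be illustrated with a simplified pure-Python sketch. This is a
greedy NMS written for illustration under the assumption that boxes overlapping a higher-scoring kept box by
more than the IoU threshold are dropped; it is not the ``torchvision.ops.nms`` implementation. The function
names and the detections below are made up.

.. code-block:: python

    def iou(a, b):
        """Intersection over union of two [xmin, ymin, xmax, ymax] boxes."""
        inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
        inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
        inter = inter_w * inter_h
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        return inter / union if union > 0 else 0.0


    def greedy_nms(boxes, scores, iou_threshold):
        """Return the indices of the kept boxes, highest score first."""
        order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
        keep = []
        for i in order:
            # Keep a box only if it does not overlap too much with any kept box
            if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
                keep.append(i)
        return keep


    def postprocess(boxes, scores, labels, iou_threshold=0.2, score_threshold=0.7):
        """Apply NMS, drop low scores, and emit [x, y, w, h, score, class_id] rows."""
        rows = []
        for i in greedy_nms(boxes, scores, iou_threshold):
            if scores[i] > score_threshold:
                xmin, ymin, xmax, ymax = boxes[i]
                rows.append([xmin, ymin, xmax - xmin, ymax - ymin, scores[i], labels[i]])
        return rows


    # Made-up detections: two heavily overlapping boxes and one low-confidence box
    boxes = [[10, 10, 50, 50], [12, 12, 52, 52], [100, 100, 140, 150]]
    scores = [0.9, 0.8, 0.6]
    labels = [1, 1, 1]
    kept = greedy_nms(boxes, scores, 0.2)
    rows = postprocess(boxes, scores, labels)
    print(kept)
    print(rows)

Here the second box is suppressed because it overlaps the first one, and the third box survives NMS but is
removed by the score threshold, so a single prediction row remains.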
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 47.102 seconds)


.. _sphx_glr_download_user-guide_vision_auto_quickstarts_plot_detection_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_detection_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_detection_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_