.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "vision/auto_tutorials/quickstarts/plot_detection_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_vision_auto_tutorials_quickstarts_plot_detection_tutorial.py:

.. _vision__detection_tutorial:

==========================
Object Detection Tutorial
==========================

In this tutorial, you will learn how to validate your **object detection model** using
deepchecks test suites. You can read more about the different checks and suites for
computer vision use cases at the :ref:`examples section `.

If you just want to see the output of this tutorial, jump to the
:ref:`observing the results ` section.

An object detection task usually consists of two parts:

- Object Localization, where the model predicts the location of an object in the image,
- Object Classification, where the model predicts the class of the detected object.

The common output of an object detection model is a list of bounding boxes around the
objects, and their classes.

.. code-block:: bash

    # Before we start, if you don't have the deepchecks vision package installed yet, run:
    import sys
    !{sys.executable} -m pip install "deepchecks[vision]" --quiet --upgrade # --user

    # or install using pip from your python environment

.. GENERATED FROM PYTHON SOURCE LINES 32-49

Defining the data and model
===========================

.. note::
    In this tutorial, we use PyTorch to create the dataset and model. To see how this can be
    done using TensorFlow or other frameworks, please visit the
    :ref:`vision__vision_data_class` guide.

Load Data
~~~~~~~~~

The model in this tutorial is used to detect tomatoes in images. The model is trained on a
dataset consisting of 895 images of tomatoes, with bounding box annotations provided in
PASCAL VOC format. All annotations belong to a single class: tomato.

.. note::
    The dataset is available at the following link:
    https://www.kaggle.com/andrewmvd/tomato-detection

    We thank the authors of the dataset for making it available.

.. GENERATED FROM PYTHON SOURCE LINES 49-125
.. code-block:: default

    import os
    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from PIL import Image
    import xml.etree.ElementTree as ET
    import urllib.request
    import zipfile

    url = 'https://figshare.com/ndownloader/files/34488599'
    urllib.request.urlretrieve(url, './tomato-detection.zip')

    with zipfile.ZipFile('./tomato-detection.zip', 'r') as zip_ref:
        zip_ref.extractall('.')

    class TomatoDataset(Dataset):
        def __init__(self, root, transforms):
            self.root = root
            self.transforms = transforms

            self.images = list(sorted(os.listdir(os.path.join(root, 'images'))))
            self.annotations = list(sorted(os.listdir(os.path.join(root, 'annotations'))))

        def __getitem__(self, idx):
            img_path = os.path.join(self.root, "images", self.images[idx])
            ann_path = os.path.join(self.root, "annotations", self.annotations[idx])

            img = Image.open(img_path).convert("RGB")
            bboxes, labels = [], []
            with open(ann_path, 'r') as f:
                root = ET.parse(f).getroot()
                for obj in root.iter('object'):
                    difficult = obj.find('difficult').text
                    if int(difficult) == 1:
                        continue
                    cls_id = 1
                    xmlbox = obj.find('bndbox')
                    b = [float(xmlbox.find('xmin').text), float(xmlbox.find('ymin').text),
                         float(xmlbox.find('xmax').text), float(xmlbox.find('ymax').text)]
                    bboxes.append(b)
                    labels.append(cls_id)

            bboxes = torch.as_tensor(np.array(bboxes), dtype=torch.float32)
            labels = torch.as_tensor(np.array(labels), dtype=torch.int64)

            if self.transforms is not None:
                res = self.transforms(image=np.array(img), bboxes=bboxes, class_labels=labels)

            target = {
                'boxes': [torch.Tensor(x) for x in res['bboxes']],
                'labels': res['class_labels']
            }

            img = res['image']
            return img, target

        def __len__(self):
            return len(self.images)

    data_transforms = A.Compose([
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

    dataset = TomatoDataset(root='./tomato-detection/data', transforms=data_transforms)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [int(len(dataset) * 0.9), len(dataset) - int(len(dataset) * 0.9)],
        generator=torch.Generator().manual_seed(42))
    test_dataset.transforms = A.Compose([ToTensorV2()])

.. GENERATED FROM PYTHON SOURCE LINES 126-129

Visualize the dataset
~~~~~~~~~~~~~~~~~~~~~

Let's see what our data looks like.

.. GENERATED FROM PYTHON SOURCE LINES 129-136

.. code-block:: default

    print(f'Number of training images: {len(train_dataset)}')
    print(f'Number of test images: {len(test_dataset)}')
    print(f'Example output of an image shape: {train_dataset[0][0].shape}')
    print(f'Example output of a label: {train_dataset[0][1]}')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Number of training images: 805
    Number of test images: 90
    Example output of an image shape: torch.Size([3, 224, 224])
    Example output of a label: {'boxes': [tensor([ 0.00000, 75.13600, 39.68000, 165.75999]), tensor([ 0.00000, 0.00000, 94.08000, 93.56800])], 'labels': [tensor(1), tensor(1)]}
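The example label above is printed only as raw tensors. As an optional sanity check (this
snippet is not part of the original tutorial; it assumes ``matplotlib`` is installed and reuses
``torch`` and ``train_dataset`` from the cells above), you can draw one training sample together
with its bounding boxes, undoing the normalization applied by the transforms so the image
displays correctly:

.. code-block:: default

    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    # Take one transformed sample from the training set defined above.
    img, target = train_dataset[0]

    # Undo the ImageNet normalization so the image renders with natural colors,
    # then convert from CHW tensor layout to HWC for matplotlib.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = (img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    fig, ax = plt.subplots()
    ax.imshow(img)
    for box in target['boxes']:
        x_min, y_min, x_max, y_max = box.tolist()
        ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, edgecolor='red', linewidth=2))
    plt.show()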
.. GENERATED FROM PYTHON SOURCE LINES 137-145

Downloading a Pre-trained Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In this tutorial, we will download a pre-trained SSDlite model with a MobileNetV3 Large
backbone from the official PyTorch repository. For more details, please refer to the
`official documentation `_.
After downloading the model, we will fine-tune it for our particular classes. We will do this
by replacing the pre-trained head with a new one that matches our needs.

.. GENERATED FROM PYTHON SOURCE LINES 145-162

.. code-block:: default

    from functools import partial
    from torch import nn
    import torchvision
    from torchvision.models.detection import _utils as det_utils
    from torchvision.models.detection.ssdlite import SSDLiteClassificationHead

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
    in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
    num_anchors = model.anchor_generator.num_anchors_per_location()
    norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
    model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, 2, norm_layer)
    _ = model.to(device)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" to /home/runner/.cache/torch/hub/checkpoints/ssdlite320_mobilenet_v3_large_coco-a79551df.pth

      0%| | 0.00/13.4M [00:00

.. code-block:: default

            score_filter = pred['scores'][keep_boxes] > score_thrs

            # get the filtered result
            test_boxes = pred['boxes'][keep_boxes][score_filter].reshape((-1, 4))
            test_boxes[:, 2:] = test_boxes[:, 2:] - test_boxes[:, :2]  # xyxy to xywh
            test_labels = pred['labels'][keep_boxes][score_filter]
            test_scores = pred['scores'][keep_boxes][score_filter]
            processed_pred.append(
                torch.concat([test_boxes, test_scores.reshape((-1, 1)), test_labels.reshape((-1, 1))], dim=1))
        return processed_pred

.. GENERATED FROM PYTHON SOURCE LINES 257-260

Now we'll create the collate function that will be used by the DataLoader.
In PyTorch, the collate function is used to transform the output batch to any custom format,
and we'll use that in order to transform the batch to the correct format for the checks.

.. GENERATED FROM PYTHON SOURCE LINES 260-272

.. code-block:: default

    from deepchecks.vision.vision_data import BatchOutputFormat

    def deepchecks_collate_fn(batch) -> BatchOutputFormat:
        """Return a batch of images, labels and predictions in the deepchecks format."""
        # batch received as iterable of tuples of (image, label) and transformed to tuple of iterables of images and labels:
        batch = tuple(zip(*batch))
        images = get_untransformed_images(batch[0])
        labels = transform_labels_to_cxywh(batch[1])
        predictions = infer_on_images(batch[0])
        return BatchOutputFormat(images=images, labels=labels, predictions=predictions)

.. GENERATED FROM PYTHON SOURCE LINES 273-275

We have a single label here, which is the tomato class.
The label_map is a dictionary that maps the class id to the class name, for display purposes.

.. GENERATED FROM PYTHON SOURCE LINES 275-280

.. code-block:: default

    LABEL_MAP = {
        1: 'Tomato'
    }

.. GENERATED FROM PYTHON SOURCE LINES 281-283

Now that we have our updated collate function, we can recreate the dataloader in the
deepchecks format, and use it to create a VisionData object:

.. GENERATED FROM PYTHON SOURCE LINES 283-292

.. code-block:: default

    from deepchecks.vision.vision_data import VisionData

    train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=deepchecks_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=deepchecks_collate_fn)

    training_data = VisionData(batch_loader=train_loader, task_type='object_detection', label_map=LABEL_MAP)
    test_data = VisionData(batch_loader=test_loader, task_type='object_detection', label_map=LABEL_MAP)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)
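Before relying on deepchecks' own validation in the next section, you can optionally peek at a
single batch produced by the new collate function. The following is a minimal sketch (not part
of the original tutorial) that only prints the structure of one deepchecks-formatted batch;
note that it runs model inference on a full batch, so it can take a while on CPU:

.. code-block:: default

    # Fetch one batch through the deepchecks collate function and inspect its structure.
    sample_batch = next(iter(train_loader))

    print(list(sample_batch.keys()))           # expected: ['images', 'labels', 'predictions']
    print(len(sample_batch['images']))         # number of images in the batch
    print(sample_batch['images'][0].shape)     # a single image, expected in HWC layout
    print(sample_batch['labels'][0][:2])       # first rows of one image's [class_id, x, y, w, h] labels
    print(sample_batch['predictions'][0][:2])  # first rows of one image's [x, y, w, h, score, class_id] predictions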
.. GENERATED FROM PYTHON SOURCE LINES 293-299

Making sure our data is in the correct format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The VisionData object automatically validates your data format and will alert you if there is
a problem. However, you can also manually view your images and labels to make sure they are in
the correct format by using the ``head`` function to conveniently visualize your data:

.. GENERATED FROM PYTHON SOURCE LINES 299-302

.. code-block:: default

    training_data.head()

.. GENERATED FROM PYTHON SOURCE LINES 303-307

Running Deepchecks' suite on our data and model!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now that we have our data and model wrapped in the deepchecks format, we can validate the model
with deepchecks' model evaluation suite. This can be done in just a few lines of code:

.. GENERATED FROM PYTHON SOURCE LINES 307-313

.. code-block:: default

    from deepchecks.vision.suites import model_evaluation

    suite = model_evaluation()
    result = suite.run(training_data, test_data)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Processing Batches:Train: |     | 0/1 [Time: 00:00]
    Processing Batches:Train: |█████| 1/1 [Time: 00:58]
    Processing Batches:Train: |█████| 1/1 [Time: 00:58]
    Computing Single Dataset Checks Train: |     | 0/4 [Time: 00:00]
    Computing Single Dataset Checks Train: |█▎   | 1/4 [Time: 00:00, Check=Mean Average Precision Report]
    Computing Single Dataset Checks Train: |██▌  | 2/4 [Time: 00:00, Check=Mean Average Recall Report]
    Computing Single Dataset Checks Train: |█████| 4/4 [Time: 00:01, Check=Weak Segments Performance]
    Computing Single Dataset Checks Train: |█████| 4/4 [Time: 00:01, Check=Weak Segments Performance]
    Processing Batches:Test: |     | 0/1 [Time: 00:00]
    Processing Batches:Test: |█████| 1/1 [Time: 00:06]
    Processing Batches:Test: |█████| 1/1 [Time: 00:06]
    Computing Single Dataset Checks Test: |     | 0/4 [Time: 00:00]
    Computing Single Dataset Checks Test: |██▌  | 2/4 [Time: 00:00, Check=Mean Average Recall Report]
    Computing Single Dataset Checks Test: |█████| 4/4 [Time: 00:01, Check=Weak Segments Performance]
    Computing Single Dataset Checks Test: |█████| 4/4 [Time: 00:01, Check=Weak Segments Performance]
    Computing Train Test Checks: |     | 0/2 [Time: 00:00]
    Computing Train Test Checks: |     | 0/2 [Time: 00:00, Check=Class Performance]
    Computing Train Test Checks: |██▌  | 1/2 [Time: 00:00, Check=Class Performance]
    Computing Train Test Checks: |██▌  | 1/2 [Time: 00:00, Check=Prediction Drift]
    Computing Train Test Checks: |█████| 2/2 [Time: 00:00, Check=Prediction Drift]
    Computing Train Test Checks: |█████| 2/2 [Time: 00:00, Check=Prediction Drift]

.. GENERATED FROM PYTHON SOURCE LINES 314-319

We also have suites for:
:func:`data integrity ` - validating a single dataset, and
:func:`train test validation ` - validating the dataset split.
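Both of these run in the same way as the model evaluation suite. As a rough illustration (this
snippet is not part of the original tutorial and assumes the suite functions referenced above
are importable from ``deepchecks.vision.suites``), they could be invoked like this:

.. code-block:: default

    from deepchecks.vision.suites import data_integrity, train_test_validation

    # Data integrity checks run on a single dataset, so each split can be validated on its own.
    integrity_result = data_integrity().run(training_data)

    # Train-test validation compares the two splits, e.g. for drift and distribution issues.
    split_result = train_test_validation().run(training_data, test_data)

    # Like the model evaluation result, these can be displayed in a notebook or saved to HTML.
    integrity_result.save_as_html('integrity.html')
    split_result.save_as_html('train_test_validation.html')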
.. GENERATED FROM PYTHON SOURCE LINES 321-326

.. _observing_the_result:

Observing the results
~~~~~~~~~~~~~~~~~~~~~

The results can be saved as an HTML file with the following code:

.. GENERATED FROM PYTHON SOURCE LINES 326-329

.. code-block:: default

    result.save_as_html('output.html')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    'output (3).html'

.. GENERATED FROM PYTHON SOURCE LINES 330-331

Or, if working inside a notebook, the output can be displayed directly by printing the result object:

.. GENERATED FROM PYTHON SOURCE LINES 331-334

.. code-block:: default

    result
.. (The interactive "Model Evaluation Suite" report is rendered here in the built documentation.)
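The report groups each check by whether its conditions passed. If you prefer to work with the
results programmatically rather than visually, the suite result can also be queried in code.
This is a small sketch (not part of the original tutorial), assuming the ``SuiteResult`` API of
recent deepchecks versions:

.. code-block:: default

    # List the checks whose conditions did not pass, e.g. to fail a CI job on them.
    for check_result in result.get_not_passed_checks():
        print(check_result.get_header())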
.. GENERATED FROM PYTHON SOURCE LINES 335-341

We can see that our model does not perform well, as shown by the "Class Performance" check
under the "Didn't Pass" section of the suite results. This is because the model was trained on
a different dataset, and was not trained to detect tomatoes.
We can also see that lowering the IoU threshold would have helped somewhat (as shown in the
"Mean Average Precision Report" check), but the overall precision would still be low.
Finally, under the "Passed" section, we can see that our drift checks have passed: the
distribution of the predictions on the training and test data is similar, so the issue lies in
the model itself rather than in the data split.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 31.159 seconds)

.. _sphx_glr_download_vision_auto_tutorials_quickstarts_plot_detection_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_detection_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_detection_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_