.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "vision/auto_checks/model_evaluation/plot_simple_model_comparison.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_vision_auto_checks_model_evaluation_plot_simple_model_comparison.py:


.. _vision__simple_model_comparison:

Simple Model Comparison
***********************

This notebook provides an overview of how to use and interpret the simple model comparison check.

**Structure:**

* `What Is the Purpose of the Check? <#what-is-the-purpose-of-the-check>`__
* `Generate Dataset <#generate-dataset>`__
* `Run the check <#run-the-check>`__

What Is the Purpose of the Check?
=================================

This check compares your current model to a "simple model". A simple model is designed to produce the best performance achievable using very simple rules, such as "always predict the most common class". The simple model serves as a **baseline**. If your model scores lower than the simple model, or about the same, this may indicate a problem with the model (e.g. it wasn't trained properly).

Use the ``strategy`` parameter to select the simple model used in the check:

================ ==========================================================================
Strategy         Description
================ ==========================================================================
prior (default)  The probability vector always contains the empirical class prior
                 distribution (i.e. the class distribution observed in the training set).
most_frequent    Always predicts the most frequent class: the probability vector is 1 for
                 that class and 0 for all other classes.
stratified       Predictions are generated by sampling one-hot vectors from a multinomial
                 distribution parametrized by the empirical class prior probabilities.
uniform          Predictions are generated uniformly at random from the unique classes
                 observed in the training labels, i.e. each class has equal probability.
================ ==========================================================================

As with the :ref:`tabular__simple_model_comparison` check, no simple model is more "correct" than the others. Each one gives a different baseline to compare against, so you can experiment with the different strategies and see how your model compares on your data.

This check applies only to classification datasets.

.. GENERATED FROM PYTHON SOURCE LINES 47-55

Generate Dataset
----------------

.. note::

    In this example, we use the PyTorch version of the MNIST dataset and model. To run this example using TensorFlow, change the import statement to::

        from deepchecks.vision.datasets.classification import mnist_tensorflow as mnist

.. GENERATED FROM PYTHON SOURCE LINES 55-59

.. code-block:: default


    from deepchecks.vision.checks import SimpleModelComparison
    from deepchecks.vision.datasets.classification import mnist_torch as mnist

.. GENERATED FROM PYTHON SOURCE LINES 60-65

.. code-block:: default


    train_ds = mnist.load_dataset(train=True, object_type='VisionData')
    test_ds = mnist.load_dataset(train=False, object_type='VisionData')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

.. GENERATED FROM PYTHON SOURCE LINES 66-71

Run the check
-------------

We will run the check with the stratified strategy. The check uses its default classification scorers; in the output below, the reported metric is per-class F1. You can override the default by passing alternative scorers through the ``scorers`` parameter.

.. GENERATED FROM PYTHON SOURCE LINES 71-75

.. code-block:: default


    check = SimpleModelComparison(strategy='stratified')
    result = check.run(train_ds, test_ds)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Processing Train Batches: |     | 0/1 [Time: 00:00]
    Processing Train Batches: |█████| 1/1 [Time: 00:01]
    Processing Train Batches: |█████| 1/1 [Time: 00:01]
    Processing Test Batches: |     | 0/1 [Time: 00:00]
    Processing Test Batches: |█████| 1/1 [Time: 00:05]
    Processing Test Batches: |█████| 1/1 [Time: 00:05]
    Computing Check: |     | 0/1 [Time: 00:00]
    Computing Check: |█████| 1/1 [Time: 00:00]
    Computing Check: |█████| 1/1 [Time: 00:00]

.. GENERATED FROM PYTHON SOURCE LINES 76-78

.. code-block:: default


    result.show()
*[Check output: Simple Model Comparison]*

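The run above used ``strategy='stratified'``. To show roughly what the four strategies in the table at the top of this page produce, here is a minimal NumPy sketch that builds the simple model's probability vectors from a set of training labels. The function ``simple_model_proba`` is hypothetical and written only for this illustration. It is not part of the deepchecks API and may differ from how the check builds its baseline internally.

.. code-block:: python

    import numpy as np


    def simple_model_proba(strategy, train_labels, n_samples, n_classes, seed=0):
        """Hypothetical helper (not deepchecks API): return an (n_samples, n_classes)
        matrix of baseline probability vectors for the given strategy."""
        rng = np.random.default_rng(seed)
        train_labels = np.asarray(train_labels)
        # Empirical class prior: the share of each class in the training labels.
        prior = np.bincount(train_labels, minlength=n_classes) / len(train_labels)

        if strategy == "prior":
            # Every sample gets the class prior as its probability vector.
            return np.tile(prior, (n_samples, 1))
        if strategy == "most_frequent":
            # One-hot vector on the most common training class.
            proba = np.zeros((n_samples, n_classes))
            proba[:, np.argmax(prior)] = 1.0
            return proba
        if strategy == "stratified":
            # One-hot vectors sampled from a multinomial parametrized by the prior.
            return rng.multinomial(1, prior, size=n_samples).astype(float)
        if strategy == "uniform":
            # Each prediction is a class drawn uniformly from the observed classes.
            picks = rng.choice(np.unique(train_labels), size=n_samples)
            return np.eye(n_classes)[picks]
        raise ValueError(f"unknown strategy: {strategy!r}")


    labels = [0, 0, 0, 1, 2, 2]
    print(simple_model_proba("prior", labels, n_samples=2, n_classes=3))
    print(simple_model_proba("most_frequent", labels, n_samples=2, n_classes=3))

In this sketch, ``prior`` and ``most_frequent`` are deterministic. ``stratified`` and ``uniform`` draw random one-hot vectors, so their scores depend on the seed.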
.. GENERATED FROM PYTHON SOURCE LINES 79-80

To display the results in an IDE like PyCharm, you can use the following code:

.. GENERATED FROM PYTHON SOURCE LINES 80-82

.. code-block:: default


    # result.show_in_window()

.. GENERATED FROM PYTHON SOURCE LINES 83-84

The result will be displayed in a new window.

.. GENERATED FROM PYTHON SOURCE LINES 86-95

Observe the check's output
--------------------------

The results show that the check calculates the score for each class in the dataset and compares the scores of our model and the simple model.

In addition to the graphic output, the check returns a value that includes all of the information needed to define validation conditions. The value is a dataframe containing the metric values for each class and model (the simple model, a perfect model, and the given model):

.. GENERATED FROM PYTHON SOURCE LINES 95-98

.. code-block:: default


    result.value.sort_values(by=['Class', 'Metric']).head(10)

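+----+---------------+--------+-------+------------+-------------------+----------+
|    | Model         | Metric | Class | Class Name | Number of samples | Value    |
+====+===============+========+=======+============+===================+==========+
| 8  | Simple Model  | F1     | 0     | 0          | 980               | 0.088576 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 10 | Perfect Model | F1     | 0     | 0          | 980               | 1.000000 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 20 | Given Model   | F1     | 0     | 0          | 980               | 0.985316 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 0  | Simple Model  | F1     | 1     | 1          | 1135              | 0.127734 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 11 | Perfect Model | F1     | 1     | 1          | 1135              | 1.000000 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 28 | Given Model   | F1     | 1     | 1          | 1135              | 0.956600 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 9  | Simple Model  | F1     | 2     | 2          | 1032              | 0.087613 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 12 | Perfect Model | F1     | 2     | 2          | 1032              | 1.000000 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 23 | Given Model   | F1     | 2     | 2          | 1032              | 0.978334 |
+----+---------------+--------+-------+------------+-------------------+----------+
| 2  | Simple Model  | F1     | 3     | 3          | 1010              | 0.103704 |
+----+---------------+--------+-------+------------+-------------------+----------+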


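``result.value`` is in long format, with one row per model, metric, and class. It can be easier to read after pivoting it to one row per class and one column per model. The sketch below assumes the column names shown above and uses a small hand-built stand-in for ``result.value``, copied from the first nine displayed rows. With a real result, you would pass ``result.value`` itself to pandas' ``pivot_table``.

.. code-block:: python

    import pandas as pd

    # Stand-in for result.value: the F1 rows for classes 0-2 copied from the output above.
    value = pd.DataFrame(
        {
            "Model": ["Simple Model", "Perfect Model", "Given Model"] * 3,
            "Metric": ["F1"] * 9,
            "Class": [0, 0, 0, 1, 1, 1, 2, 2, 2],
            "Value": [0.088576, 1.000000, 0.985316,
                      0.127734, 1.000000, 0.956600,
                      0.087613, 1.000000, 0.978334],
        }
    )

    # One row per (Metric, Class) pair and one column per model.
    wide = value.pivot_table(index=["Metric", "Class"], columns="Model", values="Value")
    print(wide)

Each (metric, class) pair has exactly one value per model, so the default mean aggregation of ``pivot_table`` simply passes the values through.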
.. GENERATED FROM PYTHON SOURCE LINES 99-116

Define a condition
==================

We can add a condition to the check that validates that our model outperforms the simple model by a given margin, called the gain. For classification, the gain is checked for each class separately, and the condition fails if any class does not reach the defined gain.

The performance gain is the fraction of the "remaining" unattained performance that our model achieves. The remaining performance is the gap between the simple model's score and a perfect score. The gain reflects how significant an improvement is. Take, for example, a metric bounded between 0 and 1. An improvement of only 0.03 that takes us from 0.95 to 0.98 is highly significant (especially in an imbalanced scenario), whereas improving from 0.1 to 0.13 is not a great achievement.

The gain is calculated as:

:math:`gain = \frac{\text{model score} - \text{simple score}}{\text{perfect score} - \text{simple score}}`

Let's add a condition to the check and see what happens when it fails:

.. GENERATED FROM PYTHON SOURCE LINES 116-122

.. code-block:: default


    check = SimpleModelComparison(strategy='stratified')
    check.add_condition_gain_greater_than(min_allowed_gain=0.99)
    result = check.run(train_ds, test_ds)
    result.show()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Processing Train Batches: |     | 0/1 [Time: 00:00]
    Processing Train Batches: |█████| 1/1 [Time: 00:01]
    Processing Train Batches: |█████| 1/1 [Time: 00:01]
    Processing Test Batches: |     | 0/1 [Time: 00:00]
    Processing Test Batches: |█████| 1/1 [Time: 00:01]
    Processing Test Batches: |█████| 1/1 [Time: 00:01]
    Computing Check: |     | 0/1 [Time: 00:00]
    Computing Check: |█████| 1/1 [Time: 00:00]
    Computing Check: |█████| 1/1 [Time: 00:00]

*[Check output: Simple Model Comparison]*

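To make the gain formula concrete, the sketch below applies it in two places. It uses a plain-Python helper named ``gain``, written only for illustration. The first case is the two 0.03 improvements from the text, treating the starting score as the baseline. The second is the per-class F1 values displayed earlier in ``result.value``. It assumes a perfect score of 1.0, as in the *Perfect Model* rows, and treats the threshold as strict. The check's own computation may differ in details.

.. code-block:: python

    def gain(model_score, simple_score, perfect_score=1.0):
        """Share of the gap between the simple and the perfect score closed by the model."""
        return (model_score - simple_score) / (perfect_score - simple_score)


    # The two improvements of 0.03 from the text, with the starting score as the baseline.
    print(round(gain(0.98, 0.95), 3))  # 0.6
    print(round(gain(0.13, 0.10), 3))  # 0.033

    # (given model F1, simple model F1) per class, copied from the result.value output.
    f1_scores = {0: (0.985316, 0.088576), 1: (0.956600, 0.127734), 2: (0.978334, 0.087613)}
    min_allowed_gain = 0.99
    for cls, (given, simple) in f1_scores.items():
        g = gain(given, simple)
        print(f"class {cls}: gain={g:.4f}", "pass" if g > min_allowed_gain else "fail")

All three displayed classes come out below 0.99 (about 0.984, 0.950, and 0.976). This is consistent with the condition failing for the ``min_allowed_gain=0.99`` used above.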
.. GENERATED FROM PYTHON SOURCE LINES 123-125

For several classes, the gain did not reach the target we defined, so the condition failed.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 10.904 seconds)

.. _sphx_glr_download_vision_auto_checks_model_evaluation_plot_simple_model_comparison.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_simple_model_comparison.py `

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_simple_model_comparison.ipynb `

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_