Whole Dataset Drift#

This notebook provides an overview of using and understanding the whole dataset drift check.

Structure:

  • What is a dataset drift?

  • How deepchecks detects dataset drift

  • Loading the Data

  • Run the check

  • Introduce drift to dataset

  • Define a condition

What is a dataset drift?#

A whole dataset drift, or a multivariate dataset drift, occurs when the statistical properties of our input features change, denoted by a change in the distribution P(X).

Causes of data drift include:

  • Upstream process changes, such as a sensor being replaced that changes the units of measurement from inches to centimeters.

  • Data quality issues, such as a broken sensor always reading 0.

  • Natural drift in the data, such as mean temperature changing with the seasons.

  • Change in the relation between features, or covariate shift.

The difference between a feature drift (or univariate dataset drift) and a multivariate drift is that in the latter the data drift occurs in more than one feature.

In the context of machine learning, drift between the training set and the test set means that the model was trained on data that differs from the current test data, so it will probably make more mistakes when predicting the target variable.

How deepchecks detects dataset drift#

There are many methods to detect feature drift. Some of them are statistical methods that aim to measure the difference between the distributions of two given sets. These methods are better suited to univariate distributions and are primarily used to detect drift between two subsets of a single feature, as in the illustration below.
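For example, a two-sample Kolmogorov-Smirnov test can compare the distribution of a single numeric feature between two sets. This is a generic illustration of a univariate test on synthetic data, not the specific method deepchecks uses for feature drift:

import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
train_feature = rng.normal(loc=0.0, scale=1.0, size=1000)  # reference sample
test_feature = rng.normal(loc=0.5, scale=1.0, size=1000)   # shifted sample

statistic, p_value = ks_2samp(train_feature, test_feature)
print(f"KS statistic: {statistic:.3f}, p-value: {p_value:.3g}")
# A small p-value indicates the two samples are unlikely to come from the
# same distribution, i.e. the feature has drifted.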

Measuring multivariate data drift is a bit more challenging. In the whole dataset drift check, the multivariate drift is measured by training a classifier that detects which dataset each sample comes from, and defining the drift by how well this classifier can tell the two sets apart.

Practically, the check concatenates the train and the test sets, assigns label 0 to samples that come from the training set and label 1 to those that come from the test set, then trains a binary classifier of type Histogram-based Gradient Boosting Classification Tree and derives the drift score from the AUC of this classifier.
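The snippet below is a minimal sketch of this domain-classifier idea using scikit-learn. It is an illustration of the approach, not the exact deepchecks implementation: it assumes all features are numeric, and the final scaling of the AUC into a drift score is an assumption for the sketch.

import numpy as np
import pandas as pd
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def domain_classifier_drift(train_df: pd.DataFrame, test_df: pd.DataFrame) -> float:
    # Concatenate both sets and mark their origin: 0 = train, 1 = test.
    combined = pd.concat([train_df, test_df], ignore_index=True)
    origin = np.concatenate([np.zeros(len(train_df)), np.ones(len(test_df))])

    # Hold out part of the combined data to evaluate the domain classifier.
    X_fit, X_eval, y_fit, y_eval = train_test_split(
        combined, origin, test_size=0.3, random_state=0, stratify=origin)

    # Assumes all columns are numeric; categorical columns would need encoding.
    clf = HistGradientBoostingClassifier(random_state=0)
    clf.fit(X_fit, y_fit)

    auc = roc_auc_score(y_eval, clf.predict_proba(X_eval)[:, 1])
    # An AUC of 0.5 means train and test are indistinguishable (no drift).
    # Rescale so that 0.5 -> 0 and 1.0 -> 1 (this scaling is an assumption
    # of the sketch, not necessarily the score deepchecks reports).
    return max(2 * auc - 1, 0)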

Loading the Data#

The dataset is the adult dataset, which can be downloaded from the UCI Machine Learning Repository.

Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.

from urllib.request import urlopen

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

from deepchecks.tabular import Dataset
from deepchecks.tabular.datasets.classification import adult

Create Dataset#

label_name = 'income'
train_ds, test_ds = adult.load_data()
# Encode the string label column into integers.
encoder = LabelEncoder()
train_ds.data[label_name] = encoder.fit_transform(train_ds.data[label_name])
test_ds.data[label_name] = encoder.transform(test_ds.data[label_name])
train_ds.label_name

Out:

'income'

Run the check#

from deepchecks.tabular.checks import WholeDatasetDrift

check = WholeDatasetDrift()
check.run(train_dataset=train_ds, test_dataset=test_ds)

Out:

Calculating permutation feature importance. Expected to finish in 3 seconds

Whole Dataset Drift

Calculate drift between the entire train and test datasets using a model trained to distinguish between them.

Additional Outputs

Nothing to display



We can see that there is almost no drift found between the train and the test set of the raw adult dataset. In addition to the drift score, the check displays the top features that contributed to the data drift.

Introduce drift to dataset#

Now, let's try to introduce a manual data drift by sampling a biased portion of the training data.

sample_size = 10000
random_seed = 0
# Bias the training sample: combine min(sample_size, n_samples) - 5000 randomly
# drawn rows with 5000 rows where sex == ' Female', so the train distribution
# no longer matches the test distribution.
train_drifted_df = pd.concat([train_ds.data.sample(min(sample_size, train_ds.n_samples) - 5000, random_state=random_seed),
                              train_ds.data[train_ds.data['sex'] == ' Female'].sample(5000, random_state=random_seed)])
# The test sample is drawn uniformly at random, without the bias.
test_drifted_df = test_ds.data.sample(min(sample_size, test_ds.n_samples), random_state=random_seed)

train_drifted_ds = Dataset(train_drifted_df, label=label_name, cat_features=train_ds.cat_features)
test_drifted_ds = Dataset(test_drifted_df, label=label_name, cat_features=test_ds.cat_features)
check = WholeDatasetDrift()
check.run(train_dataset=train_drifted_ds, test_dataset=test_drifted_ds)

Out:

Calculating permutation feature importance. Expected to finish in 4 seconds

Whole Dataset Drift

Calculate drift between the entire train and test datasets using a model trained to distinguish between them.

Additional Outputs
The shown features are the features that are most important for the domain classifier - the domain_classifier trained to distinguish between the train and test datasets.
The percents of explained dataset difference are the importance values for the feature calculated using `permutation_importance`.

Main features contributing to drift

* showing only the top 3 columns, you can change it using n_top_columns param


As expected, the check detects a multivariate drift between the train and the test sets. It also displays the sex feature’s distribution - the feature that contributed the most to that drift. This is reasonable since the sampling was biased based on that feature.

Define a condition#

Now, we define a condition that enforces that the whole dataset drift score is not greater than 0.1. A condition is deepchecks' way to validate model and data quality, and it lets you know if anything goes wrong.

check = WholeDatasetDrift()
check.add_condition_overall_drift_value_not_greater_than(0.1)
check.run(train_dataset=train_drifted_ds, test_dataset=test_drifted_ds)

Out:

Calculating permutation feature importance. Expected to finish in 3 seconds

Whole Dataset Drift

Calculate drift between the entire train and test datasets using a model trained to distinguish between them.

Conditions Summary
Status   Condition                              More Info
Failed   Drift value is not greater than 0.1    Found drift value of: 0.35, corresponding to a domain classifier AUC of: 0.67
Additional Outputs
The shown features are the features that are most important for the domain classifier - the domain_classifier trained to distinguish between the train and test datasets.
The percents of explained dataset difference are the importance values for the feature calculated using `permutation_importance`.

Main features contributing to drift

* showing only the top 3 columns, you can change it using n_top_columns param


As we can see, the condition correctly detects that the drift score is above the defined threshold.

