Custom Check Templates#
This page supplies templates for the different types of custom checks that you can create using the deepchecks package. For more information on custom checks, please see the Custom Check Guide.
Templates: Single Dataset Check, Train Test Check, Model Only Check.
Single Dataset Check#
Check type for cases where the check runs on a single dataset and an optional model, for example integrity checks. When the check is run as part of a suite that is given two datasets, it runs on each dataset independently.
from deepchecks.core import CheckResult, ConditionCategory, ConditionResult, DatasetKind
from deepchecks.vision import SingleDatasetCheck, Context, VisionData, Batch
class SingleDatasetCustomCheck(SingleDatasetCheck):
    """Description of the check. The name of the check will be the class name split by upper case letters."""

    # OPTIONAL: we can add different properties in the init
    def __init__(self, prop_a: str, prop_b: str, **kwargs):
        """Store the check properties; remaining keyword arguments are forwarded to the base class."""
        super().__init__(**kwargs)
        self.prop_a = prop_a
        self.prop_b = prop_b

    def initialize_run(self, context: Context, dataset_kind: DatasetKind):
        """Reset the cache that `update` fills and `compute` reads."""
        # Initialize cache
        self.cache = {}
        # OPTIONAL: add validations on inputs and properties like prop_a and prop_b

    # NOTE: the annotation uses `Batch`, the name imported from deepchecks.vision
    # in this template; an un-imported name here would raise NameError as soon
    # as the class body is executed.
    def update(self, context: Context, batch: Batch, dataset_kind: DatasetKind):
        """Extract the needed data from a single batch and merge it into the cache."""
        # Get the VisionData by its type (train/test)
        dataset: VisionData = context.get_data_by_kind(dataset_kind)
        # Take the data we need from the batch
        batch_data_dict = some_calc_on_batch(batch, dataset)
        # Save the data on the cache
        self.cache.update(batch_data_dict)

    def compute(self, context: Context, dataset_kind: DatasetKind) -> CheckResult:
        """Compute the check result from the data accumulated in the cache.

        Returns a CheckResult whose value is a dict with:
        'ratio'   - number of failing samples divided by number of cached samples
                    (0.0 when the cache is empty),
        'indices' - list of the failing samples' keys.
        """
        # LOGIC HERE
        failing_samples = some_calc_on_cache(self.cache, self.prop_a, self.prop_b)
        total_samples = len(self.cache)
        # Define result value: Adding any info that we might want to know later
        result = {
            # Guard the division: an empty cache would otherwise raise ZeroDivisionError
            'ratio': len(failing_samples) / total_samples if total_samples else 0.0,
            # Materialize the keys so the result does not hold a live view of the dict
            'indices': list(failing_samples.keys())
        }
        # Define result display: list of either plotly-figure/dataframe/html
        display = None
        return CheckResult(result, display=display)

    # OPTIONAL: add condition to check
    def add_condition_ratio_less_than(self, threshold: float = 0.01):
        """Add a condition that passes when the result's 'ratio' is below `threshold`."""
        # Define condition function: the function accepts as input the result value we defined in `compute`
        def condition(result):
            ratio = result['ratio']
            category = ConditionCategory.PASS if ratio < threshold else ConditionCategory.FAIL
            message = f'Found X ratio of {ratio}'
            return ConditionResult(category, message)

        # Define the name of the condition
        name = f'Custom check ratio is less than {threshold}'
        # Now add it on the class instance
        return self.add_condition(name, condition)
Train Test Check#
Check type for cases where the check runs on two datasets and an optional model, for example drift checks.
from deepchecks.core import CheckResult, ConditionCategory, ConditionResult, DatasetKind
from deepchecks.vision import TrainTestCheck, Context, VisionData, Batch
# NOTE(review): this class name looks copy-pasted from the single-dataset
# template; consider renaming it (e.g. TrainTestCustomCheck) when adapting it.
class SingleDatasetCustomCheck(TrainTestCheck):
    """Description of the check. The name of the check will be the class name split by upper case letters."""

    # OPTIONAL: we can add different properties in the init
    def __init__(self, prop_a: str, prop_b: str, **kwargs):
        """Store the check properties; remaining keyword arguments are forwarded to the base class."""
        super().__init__(**kwargs)
        self.prop_a = prop_a
        self.prop_b = prop_b

    def initialize_run(self, context: Context):
        """Reset the cache, keeping one separate sub-cache per dataset kind."""
        # Initialize cache
        self.cache = {
            DatasetKind.TRAIN: {},
            DatasetKind.TEST: {}
        }
        # OPTIONAL: add validations on inputs and properties like prop_a and prop_b

    # NOTE: the annotation uses `Batch`, the name imported from deepchecks.vision
    # in this template; an un-imported name here would raise NameError as soon
    # as the class body is executed.
    def update(self, context: Context, batch: Batch, dataset_kind: DatasetKind):
        """Extract the needed data from a single batch and merge it into that dataset kind's cache."""
        # Get the VisionData by its type (train/test)
        dataset: VisionData = context.get_data_by_kind(dataset_kind)
        # Take the data we need from the batch
        batch_data_dict = some_calc_on_batch(batch, dataset)
        # Save the data on the cache
        self.cache[dataset_kind].update(batch_data_dict)

    def compute(self, context: Context) -> CheckResult:
        """Compute the check result from the data accumulated for both datasets.

        Returns a CheckResult whose value is a dict with:
        'ratio'   - number of failing samples divided by the total number of
                    cached samples over train and test (0.0 when both are empty),
        'indices' - list of the failing samples' keys.
        """
        # Get the VisionData (available here for use in the logic below)
        train_vision_data: VisionData = context.train
        test_vision_data: VisionData = context.test
        # LOGIC HERE
        failing_samples = some_calc_on_cache(self.cache, self.prop_a, self.prop_b)
        # The cache is keyed by dataset kind, so len(self.cache) is always 2;
        # count the samples stored inside each per-kind cache instead.
        total_samples = sum(len(kind_cache) for kind_cache in self.cache.values())
        # Define result value: Adding any info that we might want to know later
        result = {
            # Guard the division: empty caches would otherwise raise ZeroDivisionError
            'ratio': len(failing_samples) / total_samples if total_samples else 0.0,
            # Materialize the keys so the result does not hold a live view of the dict
            'indices': list(failing_samples.keys())
        }
        # Define result display: list of either plotly-figure/dataframe/html
        display = None
        return CheckResult(result, display=display)

    # OPTIONAL: add condition to check
    def add_condition_ratio_less_than(self, threshold: float = 0.01):
        """Add a condition that passes when the result's 'ratio' is below `threshold`."""
        # Define condition function: the function accepts as input the result value we defined in `compute`
        def condition(result):
            ratio = result['ratio']
            category = ConditionCategory.PASS if ratio < threshold else ConditionCategory.FAIL
            message = f'Found X ratio of {ratio}'
            return ConditionResult(category, message)

        # Define the name of the condition
        name = f'Custom check ratio is less than {threshold}'
        # Now add it on the class instance
        return self.add_condition(name, condition)
Model Only Check#
Check type for cases where the check runs only on a model, for example a model-parameters check.
from deepchecks.core import CheckResult, ConditionCategory, ConditionResult
from deepchecks.vision import ModelOnlyCheck, Context
class ModelOnlyCustomCheck(ModelOnlyCheck):
    """Description of the check. The name of the check will be the class name split by upper case letters."""

    # OPTIONAL: we can add different properties in the init
    def __init__(self, prop_a: str, prop_b: str, **kwargs):
        """Store the check properties; remaining keyword arguments are forwarded to the base class."""
        super().__init__(**kwargs)
        self.prop_a = prop_a
        self.prop_b = prop_b

    def compute(self, context: Context) -> CheckResult:
        """Compute a score from the model and return it as the CheckResult value."""
        # Get the model
        model = context.model
        # LOGIC HERE - possible to add validations on inputs and properties like prop_a and prop_b
        some_score = some_calc_fn(model, self.prop_a, self.prop_b)
        # Define result value: Adding any info that we might want to know later
        result = some_score
        # Define result display: list of either plotly-figure/dataframe/html, or Nothing if we have no display
        display = None
        return CheckResult(result, display=display)

    # OPTIONAL: add condition to check
    def add_condition_score_more_than(self, threshold: float = 1):
        """Add a condition that passes when the result score is greater than `threshold`."""
        # Define condition function: the function accepts as input the result value we defined in `compute`
        def condition(result):
            # Compare against the caller-supplied threshold (not a hard-coded 1),
            # so the verdict agrees with the condition name defined below.
            category = ConditionCategory.PASS if result > threshold else ConditionCategory.FAIL
            message = f'Found X score of {result}'
            return ConditionResult(category, message)

        # Define the name of the condition
        name = f'Custom check score is more than {threshold}'
        # Now add it on the class instance
        return self.add_condition(name, condition)