Compare learners with respect to one or multiple metrics. Metrics can be, but are not limited to, fairness metrics.

Usage

compare_metrics(object, ...)

Arguments

object

(PredictionClassif | BenchmarkResult | ResampleResult)
The object to create a plot for.

  • If object is a PredictionClassif, the plot compares the fairness metrics between the binary levels of the protected field using bar plots.

  • If object is a ResampleResult, the plot shows boxplots of the fairness metrics and compares them between the binary levels of the protected field.

  • If object is a BenchmarkResult, the plot shows boxplots of the fairness metrics and compares them between both the binary levels of the protected field and the benchmarked models.

...

The arguments to be passed to methods, such as:

  • fairness_measures (list of Measure)
    The fairness measures to be evaluated on object; either a single Measure or a list of Measures. Defaults to msr("fairness.acc").

  • task (TaskClassif)
    The task that contains the protected column. Only required when object is a PredictionClassif.

Value

A 'ggplot2' object.
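
Because the return value is a 'ggplot2' object, it can presumably be extended with the usual ggplot2 layers. The following is a minimal sketch under stated assumptions: it uses classif.rpart only because that learner ships with mlr3, and the title text and theme are illustrative choices, not part of the package.

```r
library("mlr3")
library("mlr3fairness")
library("ggplot2")

# Small task and a learner that needs no extra packages (illustrative choice)
task = tsk("adult_train")$filter(1:500)
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
predictions = learner$predict(task)

# compare_metrics() returns a ggplot2 object ...
p = compare_metrics(predictions, msr("fairness.acc"), task)

# ... so standard ggplot2 components can be added to it
p = p + labs(title = "Accuracy by protected group") + theme_minimal()
print(p)
```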

Protected Attributes

The protected attribute is specified as a col_role in the corresponding Task():
<Task>$col_roles$pta = "name_of_attribute"
This also allows specifying more than one protected attribute, in which case fairness is assessed on the level of the intersecting groups defined by all columns selected as protected attributes.
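
As a short sketch of the col_role assignment described above: the example assumes the adult_train task from mlr3fairness and assumes it contains columns named sex and race. Substitute the column names of your own task.

```r
library("mlr3")
library("mlr3fairness")

task = tsk("adult_train")

# Single protected attribute (column name assumed to exist in the task)
task$col_roles$pta = "sex"

# Two protected attributes: groups are then formed by the
# combinations of the levels of both columns
task$col_roles$pta = c("sex", "race")

# Inspect the current protected attributes
task$col_roles$pta
```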

Examples

library("mlr3")
library("mlr3learners")
library("mlr3fairness")  # provides compare_metrics(), the fairness measures and tsk("adult_train")

# Set up the task, learner, benchmark design and fairness measures
task = tsk("adult_train")$filter(1:500)
learner = lrn("classif.ranger", predict_type = "prob")
learner$train(task)
predictions = learner$predict(task)
design = benchmark_grid(
  tasks = task,
  learners = lrns(c("classif.ranger", "classif.rpart"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)

bmr = benchmark(design)
fairness_measure = msr("fairness.tpr")
fairness_measures = msrs(c("fairness.tpr", "fairness.fnr", "fairness.acc"))

# Predictions
compare_metrics(predictions, fairness_measure, task)

compare_metrics(predictions, fairness_measures, task)


# BenchmarkResult (a ResampleResult is passed the same way)
compare_metrics(bmr, fairness_measure)

compare_metrics(bmr, fairness_measures)
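
The examples above do not show a ResampleResult being passed directly. A minimal sketch of that case follows, assuming classif.rpart as the learner (chosen only because it ships with mlr3) and three-fold cross-validation; no task argument is passed, since the documentation requires it only for a PredictionClassif.

```r
library("mlr3")
library("mlr3fairness")

task = tsk("adult_train")$filter(1:500)
learner = lrn("classif.rpart", predict_type = "prob")

# resample() returns a ResampleResult
rr = resample(task, learner, rsmp("cv", folds = 3))

# Boxplots of the chosen fairness measures across the resampling iterations
p = compare_metrics(rr, msrs(c("fairness.tpr", "fairness.fnr")))
print(p)
```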