Compare learners with respect to one or multiple metrics. Metrics can be, but are not limited to, fairness metrics.
Arguments
- object
(mlr3::PredictionClassif | mlr3::BenchmarkResult | mlr3::ResampleResult)
The object to create a plot for. If provided a (mlr3::PredictionClassif), the visualization compares the fairness metrics across the binary levels of the protected attribute using bar plots.
If provided a (mlr3::ResampleResult), the visualization generates boxplots of the fairness metrics and compares them across the binary levels of the protected attribute.
If provided a (mlr3::BenchmarkResult), the visualization generates boxplots of the fairness metrics and compares them across both the binary levels of the protected attribute and the learners being benchmarked.
- ...
The arguments to be passed to methods, such as:
fairness_measures
(list of mlr3::Measure)
The fairness measures that will be evaluated on object; can be a single mlr3::Measure or a list of mlr3::Measures. The default measure is msr("fairness.acc"). Non-fairness measures can be included as well (see the sketch after this argument list).
.task
(mlr3::TaskClassif)
The data task that contains the protected column; only required when object is a (mlr3::PredictionClassif).
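Since the measures are not limited to fairness metrics, standard measures can be passed alongside them. The following sketch assumes the predictions and task objects created in the Examples section below; classif.acc is the standard accuracy measure from mlr3.
measures = msrs(c("fairness.acc", "classif.acc"))
compare_metrics(predictions, measures, task)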
Protected Attributes
The protected attribute is specified as a col_role
in the corresponding mlr3::Task():
<Task>$col_roles$pta = "name_of_attribute"
This also allows specifying more than one protected attribute,
in which case fairness will be considered on the level of intersecting groups defined by all columns
selected as a protected attribute (see the sketch below).
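For example (a sketch, assuming the columns "sex" and "race" as found in the adult tasks shipped with mlr3fairness):
task = tsk("adult_train")
# single protected attribute
task$col_roles$pta = "sex"
# multiple protected attributes: fairness is then evaluated on the
# intersecting groups defined by both columns
task$col_roles$pta = c("sex", "race")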
Examples
library("mlr3")
library("mlr3learners")
# Set up the fairness measures and tasks
task = tsk("adult_train")$filter(1:500)
learner = lrn("classif.ranger", predict_type = "prob")
learner$train(task)
predictions = learner$predict(task)
design = benchmark_grid(
  tasks = task,
  learners = lrns(c("classif.ranger", "classif.rpart"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)
bmr = benchmark(design)
fairness_measure = msr("fairness.tpr")
fairness_measures = msrs(c("fairness.tpr", "fairness.fnr", "fairness.acc"))
# Predictions
compare_metrics(predictions, fairness_measure, task)
compare_metrics(predictions, fairness_measures, task)
# BenchmarkResult and ResampleResult
compare_metrics(bmr, fairness_measure)
compare_metrics(bmr, fairness_measures)
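# ResampleResult (a sketch: resample() from mlr3 with 3-fold CV on the same task)
rr = resample(task, lrn("classif.rpart", predict_type = "prob"), rsmp("cv", folds = 3))
compare_metrics(rr, fairness_measure)
compare_metrics(rr, fairness_measures)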