The fairGATE package provides a complete pipeline for training and evaluating a Gated Neural Network (GNN) designed to mitigate demographic bias in predictive modelling. The package implements a fairness-aware GNN that uses a custom loss function to enforce the Equalized Odds fairness criterion by minimising the variance in True Positive and False Positive Rates across subgroups.
While the package was developed for the GENDEP study (predicting antidepressant response), this vignette demonstrates the full workflow on the UCI Adult dataset, focusing on fairness across sex subgroups.
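To make the fairness objective concrete, the sketch below computes an Equalized-Odds-style penalty as the variance of the group-wise True Positive and False Positive Rates. This is an illustration in plain R of the idea behind the custom loss, not the package's internal implementation; in train_gnn() the penalty is applied to the network's outputs during training and, we assume, weighted against the prediction loss by the lambda hyper-parameter.

# Illustrative only: an Equalized Odds penalty as the variance of
# TPR and FPR across subgroups (hard 0/1 predictions for simplicity).
eo_penalty <- function(y_true, y_pred, group) {
  rates <- sapply(split(seq_along(y_true), group), function(idx) {
    yt <- y_true[idx]
    yp <- y_pred[idx]
    c(
      tpr = if (any(yt == 1)) mean(yp[yt == 1]) else NA_real_,  # P(pred = 1 | y = 1)
      fpr = if (any(yt == 0)) mean(yp[yt == 0]) else NA_real_   # P(pred = 1 | y = 0)
    )
  })
  # Penalty = variance of each rate across subgroups, summed
  var(rates["tpr", ], na.rm = TRUE) + var(rates["fpr", ], na.rm = TRUE)
}

# Toy usage: a classifier that favours one group yields a non-zero penalty
set.seed(1)
grp    <- rep(c("Female", "Male"), each = 50)
y_true <- rbinom(100, 1, 0.4)
y_pred <- rbinom(100, 1, ifelse(grp == "Male", 0.6, 0.3))
eo_penalty(y_true, y_pred, grp)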
The package is built around a logical sequence of functions:

- prepare_data(): Cleans and prepares the input data.
- train_gnn(): Trains the Gated Neural Network.
- analyse_gnn_results(): Conducts performance and gate analysis.
- analyse_experts(): Performs analysis of expert specialisation.
- plot_sankey(): Visualises the model's patient routing behaviour.

First, we load the necessary libraries and the dataset.
We use a small in-package sample of the UCI Adult dataset. Outcome is **income** (1 = >50K, 0 = ≤50K); protected attribute is **sex**.
library(fairGATE)
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
# Loading the UCI Adult Dataset
data("adult_ready_small", package = "fairGATE")
adult_data <- adult_ready_small
adult <- adult_data %>%
mutate(
across(where(is.character), ~ trimws(.x)),
income = as.integer(income)
)
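
As an optional sanity check, we can cross-tabulate the outcome and the protected attribute (both columns are used throughout this vignette):

# Quick sanity check: outcome by protected attribute
table(sex = adult$sex, income = adult$income)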
We use prepare_data() to process the raw data for the Male/Female analysis.
# Dropping unwanted columns (e.g. ID columns and those with high multicollinearity)
cols_to_drop <- c("subjectid", "Row.names")

# Be sure to perform other preprocessing steps (e.g. one-hot encoding) first;
# the fully prepared data is then passed to prepare_data() below.
prepared <- fairGATE::prepare_data(
data = adult,
outcome_var = "income",
group_var = "sex",
cols_to_remove = cols_to_drop
)
#> Auto-generated group_mappings: Female, Male

# Train a small Gated Neural Network
trained_model <- fairGATE::train_gnn(
prepared_data = prepared,
run_tuning = FALSE, # skip tuning for speed
best_params = list(
lr = 0.01,
hidden_dim = 16,
dropout_rate = 0.1,
lambda = 0.0,
temperature = 1.0
),
num_repeats = 2, # very short repeated split
epochs = 20, # fast CRAN-safe runtime
verbose = FALSE
)

With the results loaded, we run analyse_gnn_results() to generate all the standard performance plots and gate analyses.
# Run basic analysis
basic_analyses <- analyse_gnn_results(
gnn_results = trained_model,
prepared_data = prepared
)
# --- View all plots from the basic analysis ---
cat("## ROC Curve\n")
#> ## ROC Curve
print(basic_analyses$roc_plot)
cat("\n## Gate Weight Distribution\n")
#>
#> ## Gate Weight Distribution
print(basic_analyses$gate_density_plot)
cat("\n## Gate Entropy Distribution\n")
#>
#> ## Gate Entropy Distribution
print(basic_analyses$entropy_density_plot)

Now, we use analyse_experts() to investigate how the different expert networks have specialised their learning.
The analyse_experts() function summarises expert weights
per subgroup, compares mean importance across groups, and produces
difference or multi-group plots.
exp_res <- analyse_experts(
gnn_results = trained_model, # from train_gnn()
prepared_data = prepared, # from prepare_data()
top_n_features = 15, # number of top features to visualise
verbose = TRUE
)
#> Starting Expert Feature Weight Analysis...
#> Using group_mappings retrieved from prepared_data attributes.
#> Feature importance summaries computed.
# View the main objects returned
names(exp_res)
#> [1] "all_weights" "means_by_group_wide" "pairwise_differences"
#> [4] "difference_plot" "multi_group_plot" "top_features_multi"
#> [1] "all_weights" "means_by_group_wide"
#> [3] "pairwise_differences" "difference_plot"
#> [5] "multi_group_plot" "top_features_multi"
# View first few feature importances
head(exp_res$means_by_group_wide)
#> # A tibble: 6 × 3
#> feature Female Male
#> <chr> <dbl> <dbl>
#> 1 age 0.112 0.0860
#> 2 capital_gain 0.0967 0.103
#> 3 capital_loss 0.0808 0.0710
#> 4 education_10th 0.0601 0.0770
#> 5 education_11th 0.117 0.0881
#> 6 education_12th 0.0676 0.0980
# Example: view one pairwise difference table
names(exp_res$pairwise_differences)
#> [1] "Male_vs_Female"
#> [1] "Female_vs_Male"
head(exp_res$pairwise_differences[[1]])
#> # A tibble: 6 × 4
#> feature Female Male difference
#> <chr> <dbl> <dbl> <dbl>
#> 1 education_7th.8th 0.0815 0.130 0.0482
#> 2 education_Assoc.acdm 0.0886 0.0462 -0.0425
#> 3 race_Amer.Indian.Eskimo 0.0572 0.0975 0.0403
#> 4 native_country_China 0.0589 0.0988 0.0399
#> 5 education_9th 0.0633 0.101 0.0380
#> 6 education_num 0.0680 0.103 0.0352
# Visualise feature specialisation
if (!is.null(exp_res$difference_plot)) print(exp_res$difference_plot)

We can use plot_sankey() to create the key visualisation from the research paper, showing how subjects are routed through the model.
The Sankey diagram shows how subjects flow from their actual subgroup to their assigned expert. It auto-derives subgroup labels from prepared and subject IDs from trained_model.
# Generate and print the Sankey plot
p <- plot_sankey(
prepared_data = prepared, # from prepare_data()
gnn_results = trained_model, # from train_gnn()
expert_results = exp_res, # from analyse_experts()
verbose = TRUE
)
#> Sankey: deriving assigned expert per subject...
#> Sankey: using 2-axis (group -> expert).
print(p)

The export_f360() function writes a clean, plug-and-play CSV for use in IBM AI Fairness 360, containing columns for subject IDs, true labels, predicted probabilities, and the sensitive attribute. It can also include gate probabilities if desired.
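As a final step, a call along the following lines could export the results. The argument names shown here are illustrative assumptions rather than the function's confirmed signature; consult ?export_f360 for the exact interface.

# Hypothetical call - argument names are assumptions; see ?export_f360
fairGATE::export_f360(
  gnn_results = trained_model,      # predicted probabilities and subject IDs
  prepared_data = prepared,         # true labels and the sensitive attribute
  file = "adult_f360.csv",          # plug-and-play CSV for IBM AI Fairness 360
  include_gate_probs = TRUE         # optionally append gate probabilities
)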