Introduction to fairGATE

1. Introduction

The fairGATE package provides a complete pipeline for training and evaluating a Gated Neural Network (GNN) designed to mitigate demographic bias in predictive modelling. The package implements a fairness-aware GNN that uses a custom loss function to enforce the Equalized Odds fairness criterion by minimising the variance in True Positive and False Positive Rates across subgroups.
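
To make the criterion concrete, here is a minimal sketch of the quantity such a loss penalises: the variance of per-group true positive and false positive rates. The function below is illustrative only; the vectors y, pred, and group are hypothetical inputs, not fairGATE internals.

equalized_odds_gap <- function(y, pred, group) {
  # Split the true labels and binary predictions by subgroup
  rates <- sapply(split(data.frame(y = y, pred = pred), group), function(d) {
    c(tpr = mean(d$pred[d$y == 1] == 1),   # true positive rate in this group
      fpr = mean(d$pred[d$y == 0] == 1))   # false positive rate in this group
  })
  # Small when TPR and FPR are similar across groups (Equalized Odds)
  var(rates["tpr", ]) + var(rates["fpr", ])
}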

This vignette demonstrates the full workflow using a small in-package sample of the UCI Adult dataset to predict income, focusing on fairness across sex subgroups; the same pipeline applies equally to clinical data such as the GENDEP antidepressant-response study.

2. The fairGATE Workflow

The package is built around a logical sequence of functions: prepare_data() to preprocess the data, train_gnn() to fit the gated network, analyse_gnn_results() and analyse_experts() to examine performance and expert specialisation, plot_sankey() to visualise routing, and export_f360_csv() to export results for external fairness analysis.

2.1. Loading Data

First, we load the necessary libraries and the dataset.

We use a small in-package sample of the UCI Adult dataset. The outcome is income (1 = >50K, 0 = ≤50K) and the protected attribute is sex.

library(fairGATE)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

# Loading the UCI Adult Dataset

data("adult_ready_small", package = "fairGATE")
adult_data <- adult_ready_small

adult <- adult_data %>%
  mutate(
    across(where(is.character), ~ trimws(.x)),
    income = as.integer(income)
  )
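
As a quick sanity check (not part of the fairGATE API), we can cross-tabulate the protected attribute against the binary outcome:

# Class balance of the outcome within each sex subgroup
table(adult$sex, adult$income)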

2.2. Step 1: Prepare the Data

We use prepare_data() to process the raw data for the Male/Female analysis.


# Drop identifier columns and any other columns you do not want as model
# inputs (e.g. those with high multicollinearity)
cols_to_drop <- c("subjectid", "Row.names")

# Perform any other preprocessing (e.g. one-hot encoding of categorical
# variables) before this point; the fully preprocessed data is then passed
# to prepare_data()

prepared <- fairGATE::prepare_data(
  data          = adult,
  outcome_var   = "income",
  group_var     = "sex",
  cols_to_remove= cols_to_drop
)
#> Auto-generated group_mappings: Female, Male
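
The exact components of the returned object depend on the package version; a quick, generic way to inspect what prepare_data() produced:

# Inspect the top-level structure of the prepared-data object
str(prepared, max.level = 1)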

2.3. Step 2: Train a Small Demo Model

We use train_gnn() to fit the Gated Neural Network. For a fast demo we skip hyperparameter tuning (run_tuning = FALSE) and pass a fixed set of hyperparameters via best_params.

# Train a small Gated Neural Network
trained_model <- fairGATE::train_gnn(
  prepared_data = prepared,
  run_tuning    = FALSE,     # skip tuning for speed
  best_params   = list(
    lr = 0.01,
    hidden_dim = 16,
    dropout_rate = 0.1,
    lambda = 0.0,
    temperature = 1.0
  ),
  num_repeats   = 2,         # very short repeated split
  epochs        = 20,        # fast CRAN-safe runtime
  verbose       = FALSE
)
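
Here lambda = 0, so the fairness penalty is effectively switched off. Assuming, as its name suggests, that lambda weights the Equalized Odds term in the loss, a fairness-aware run would use the same call with a non-zero value (1.0 below is purely illustrative, not a tuned recommendation):

fair_model <- fairGATE::train_gnn(
  prepared_data = prepared,
  run_tuning    = FALSE,
  best_params   = list(
    lr = 0.01,
    hidden_dim = 16,
    dropout_rate = 0.1,
    lambda = 1.0,        # non-zero weight on the fairness term (illustrative)
    temperature = 1.0
  ),
  num_repeats   = 2,
  epochs        = 20,
  verbose       = FALSE
)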

2.4. Step 3: Analyse Basic Performance and Gates

With the trained model and prepared data in hand, we run analyse_gnn_results() to generate the standard performance plots and gate analyses.


# Run basic analysis
basic_analyses <- analyse_gnn_results(
  gnn_results = trained_model,
  prepared_data = prepared
)

# --- View all plots from the basic analysis ---
cat("## ROC Curve\n")
#> ## ROC Curve
print(basic_analyses$roc_plot)


cat("\n## Calibration Plot\n")
#> 
#> ## Calibration Plot
print(basic_analyses$calibration_plot)


cat("\n## Gate Weight Distribution\n")
#> 
#> ## Gate Weight Distribution
print(basic_analyses$gate_density_plot)


cat("\n## Gate Entropy Distribution\n")
#> 
#> ## Gate Entropy Distribution
print(basic_analyses$entropy_density_plot)
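
The returned object may contain further components beyond the plots printed above; names() shows what is available in your version:

# List everything returned by the basic analysis
names(basic_analyses)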

2.5. Step 4: Analyse Expert Specialisation

Now, we use analyse_experts() to investigate how the different expert networks have specialised their learning.

The analyse_experts() function summarises expert weights per subgroup, compares mean importance across groups, and produces difference or multi-group plots.

exp_res <- analyse_experts(
  gnn_results     = trained_model,   # from train_gnn()
  prepared_data   = prepared,        # from prepare_data()
  top_n_features  = 15,              # number of top features to visualise
  verbose         = TRUE
)
#> Starting Expert Feature Weight Analysis...
#> Using group_mappings retrieved from prepared_data attributes.
#> Feature importance summaries computed.

# View the main objects returned
names(exp_res)
#> [1] "all_weights"          "means_by_group_wide"  "pairwise_differences"
#> [4] "difference_plot"      "multi_group_plot"     "top_features_multi"
#> [1] "all_weights"          "means_by_group_wide" 
#> [3] "pairwise_differences" "difference_plot"     
#> [5] "multi_group_plot"     "top_features_multi"

# View first few feature importances
head(exp_res$means_by_group_wide)
#> # A tibble: 6 × 3
#>   feature        Female   Male
#>   <chr>           <dbl>  <dbl>
#> 1 age            0.112  0.0860
#> 2 capital_gain   0.0967 0.103 
#> 3 capital_loss   0.0808 0.0710
#> 4 education_10th 0.0601 0.0770
#> 5 education_11th 0.117  0.0881
#> 6 education_12th 0.0676 0.0980
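
Using only the columns shown above, we can also rank features by how far their mean importance differs between the two groups; a small dplyr sketch:

# Rank features by the absolute Female/Male gap in mean expert weight
exp_res$means_by_group_wide %>%
  mutate(abs_gap = abs(Male - Female)) %>%
  arrange(desc(abs_gap)) %>%
  head(10)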

# Example: view one pairwise difference table
names(exp_res$pairwise_differences)
#> [1] "Male_vs_Female"
#> [1] "Female_vs_Male"
head(exp_res$pairwise_differences[[1]])
#> # A tibble: 6 × 4
#>   feature                 Female   Male difference
#>   <chr>                    <dbl>  <dbl>      <dbl>
#> 1 education_7th.8th       0.0815 0.130      0.0482
#> 2 education_Assoc.acdm    0.0886 0.0462    -0.0425
#> 3 race_Amer.Indian.Eskimo 0.0572 0.0975     0.0403
#> 4 native_country_China    0.0589 0.0988     0.0399
#> 5 education_9th           0.0633 0.101      0.0380
#> 6 education_num           0.0680 0.103      0.0352

# Visualise feature specialisation
if (!is.null(exp_res$difference_plot)) print(exp_res$difference_plot)

if (!is.null(exp_res$multi_group_plot)) print(exp_res$multi_group_plot)

2.6. Step 5: Visualise Subject Routing with a Sankey Plot

We can use plot_sankey() to create the key visualisation from the research paper, showing how individual subjects are routed through the model.

The Sankey diagram shows how subjects flow from their actual subgroup to their assigned expert. It auto-derives subgroup labels from prepared and subject IDs from trained_model.

# Generate and print the Sankey plot
p <- plot_sankey(
  prepared_data  = prepared,       # from prepare_data()
  gnn_results    = trained_model,  # from train_gnn()
  expert_results = exp_res,        # from analyse_experts()
  verbose        = TRUE
)
#> Sankey: deriving assigned expert per subject...
#> Sankey: using 2-axis (group -> expert).

print(p)

2.7. (Optional) Export data for Fairness 360 or external fairness analysis

The export_f360_csv() function writes a clean plug-and-play CSV for use in IBM AI Fairness 360 (AIF360) or other external fairness tooling, containing columns for subject IDs, true labels, predicted probabilities, and the sensitive attribute. It can also include gate probabilities if desired.

export_f360_csv(
  gnn_results       = trained_model,   # from train_gnn()
  prepared_data     = prepared,        # from prepare_data()
  path              = "outputs/fairness360_input.csv",
  include_gate_cols = TRUE,            # include expert routing probabilities
  threshold         = 0.5,             # classification threshold for binary outcome
  verbose           = TRUE
)
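
The exported CSV can then be consumed by AIF360 or any other tool. As a minimal check back in R, we can compare positive prediction rates across groups; the column names below (pred_prob, sex) are assumptions about the export format, so confirm them against the actual CSV header first.

# Hedged example: read the export back and compare positive prediction rates
# (column names are assumed; check names(f360) against the real header)
f360 <- read.csv("outputs/fairness360_input.csv")
f360 %>%
  group_by(sex) %>%
  summarise(positive_rate = mean(pred_prob >= 0.5))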