Variable importance (VImp), variable interaction measures (VInt) and
partial dependence plots (PDPs) are important summaries in the
interpretation of statistical and machine learning models. In this
vignette we describe new visualization techniques for exploring these
model summaries. We construct heatmap and graph-based displays showing
variable importance and interaction jointly, which are carefully
designed to highlight important aspects of the fit. We describe a new
matrix-type layout showing all single and bivariate partial dependence
plots, and an alternative layout based on graph Eulerians focusing on
key subsets. Our new visualisations are model-agnostic and are
applicable to regression and classification supervised learning
settings. They enhance interpretation even in situations where the
number of variables is large and the interaction structure complex. Our
R package vivid (variable importance and variable interaction displays)
provides an implementation. When referring to VImp and VInt together, we
use the shorthand VIVI. For more information on visualising variable
importance and interactions in machine learning models, see our
published work [1].
Some of the plots used by vivid are built upon the zenplots package,
which requires the graph package from Bioconductor. To install the graph
and zenplots packages use:
if (!requireNamespace("graph", quietly = TRUE)){
install.packages("BiocManager")
BiocManager::install("graph")
}
install.packages("zenplots")
Now we can install vivid by using:
install.packages("vivid")
Alternatively you can install the latest development version of the package in R with the commands:
if(!require(remotes)) install.packages('remotes')
remotes::install_github('AlanInglis/vividPackage')
We then load the required packages: vivid to create the visualizations,
and some other packages to create various model fits.
library(vivid) # for visualisations
library(randomForest) # for model fit
library(ranger) # for model fit
The data used in the following examples is simulated from the Friedman benchmark problem [2]. This benchmark problem is commonly used for testing purposes. The output is created according to the equation:
\[y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + \epsilon\]
For the following examples we set the number of features to 9 and the
number of samples to 350, and fit a randomForest random forest model
with \(y\) as the response. As the features \(x_1\) to \(x_5\) are the
only variables in the equation, \(x_6\) to \(x_{9}\) are noise
variables. As can be seen from the above equation, the only interaction
is between \(x_1\) and \(x_2\).
Create the data:
set.seed(101)
genFriedman <- function(noFeatures = 10,
                        noSamples = 100,
                        sigma = 1) {
  # Set Values
  n <- noSamples # no of rows
  p <- noFeatures # no of variables
  e <- rnorm(n, sd = sigma)

  # Create matrix of values
  xValues <- matrix(runif(n * p, 0, 1), nrow = n) # Create matrix
  colnames(xValues) <- paste0("x", 1:p) # Name columns
  df <- data.frame(xValues) # Create dataframe

  # Equation:
  # y = 10sin(πx1x2) + 20(x3−0.5)^2 + 10x4 + 5x5 + ε
  y <- (10 * sin(pi * df$x1 * df$x2) + 20 * (df$x3 - 0.5)^2 + 10 * df$x4 + 5 * df$x5 + e)
  # Adding y to df
  df$y <- y
  df
}

myData <- genFriedman(noFeatures = 9, noSamples = 350, sigma = 1)
Here we fit the model. We create a random forest fit from the
randomForest package.
set.seed(1701)
rf <- randomForest(y ~ ., data = myData)
To use vivid, the first step is to compute variable importance and
interactions for a fitted model. The vivi function performs this
calculation, producing a square, symmetric matrix that contains variable
importance on the diagonal and variable interactions on the
off-diagonal. To calculate the pairwise interaction strength, Friedman’s
model-agnostic, unnormalized \(H\)-statistic [3] is used. The
unnormalized version of the \(H\)-statistic was chosen to allow a more
direct comparison of interaction effects across pairs of variables, and
its values are on the scale of the response (for regression). For the
importance, either a selected embedded importance measure can be used
(as seen in Section 4) or an agnostic permutation method [4] can be
selected.
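To make the interaction measure more concrete, here is a minimal base-R sketch of one common form of the unnormalized \(H\)-statistic: the root mean square of the difference between the centred bivariate partial dependence and the sum of the two centred univariate partial dependences, evaluated at the observed data points. The toy data, the linear model, and the helper names pdAtData and hUnnorm are assumptions made for illustration; this is not the code used inside vivi.

```r
# Toy data: x1 and x2 interact, x3 enters additively only.
set.seed(1)
n  <- 60
df <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
df$y <- 5 * df$x1 * df$x2 + 2 * df$x3 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 * x2 + x3, data = df)

# Centred partial dependence on `vars`, evaluated at every observed row
# (hypothetical helper).
pdAtData <- function(fit, data, vars) {
  pd <- vapply(seq_len(nrow(data)), function(i) {
    tmp <- data
    for (v in vars) tmp[[v]] <- data[[v]][i]  # fix vars at row i's values
    mean(predict(fit, newdata = tmp))         # average over the other variables
  }, numeric(1))
  pd - mean(pd)
}

# Unnormalized H for the pair (j, k), on the scale of the response
# (hypothetical helper).
hUnnorm <- function(fit, data, j, k) {
  resid <- pdAtData(fit, data, c(j, k)) -
    pdAtData(fit, data, j) - pdAtData(fit, data, k)
  sqrt(mean(resid^2))
}

h12 <- hUnnorm(fit, df, "x1", "x2")  # interacting pair: clearly positive
h13 <- hUnnorm(fit, df, "x1", "x3")  # additive pair: numerically zero
```

Because the sketch loops over every row for every partial dependence value, it is only practical for small data; the gridSize and nmax arguments of vivi exist to keep this cost manageable.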
The vivi function requires three inputs: a fitted machine learning
model, the data frame used to train the model, and the name of the
response variable for the fit. The resulting matrix will have importance
and interaction values for all variables in the data frame, excluding
the response variable. By default, if no embedded variable importance
method is available or selected, an agnostic permutation method is
applied. For clarity, this is shown via the importanceType = 'agnostic'
argument below. For an example of using embedded methods, see Section 4.
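The idea behind an agnostic permutation importance can be sketched in a few lines of base R: permute one predictor at a time, re-predict, and record how much the prediction error increases. The toy data, the linear model, the use of RMSE as the error measure, and the helper name permImportance are all assumptions for illustration, not the implementation used by vivi.

```r
# Toy data: x1 matters a lot, x2 a little, x3 not at all.
set.seed(42)
n  <- 200
df <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
df$y <- 10 * df$x1 + 2 * df$x2 + rnorm(n, sd = 0.5)
fit <- lm(y ~ ., data = df)

# Increase in RMSE after permuting each predictor, averaged over numPerm
# permutations (hypothetical helper).
permImportance <- function(fit, data, response, numPerm = 4) {
  rmse <- function(d) sqrt(mean((d[[response]] - predict(fit, newdata = d))^2))
  base <- rmse(data)
  vars <- setdiff(names(data), response)
  vapply(vars, function(v) {
    permuted <- replicate(numPerm, {
      tmp <- data
      tmp[[v]] <- sample(tmp[[v]])  # break the link between v and the response
      rmse(tmp)
    })
    mean(permuted) - base
  }, numeric(1))
}

imp <- permImportance(fit, df, "y")
```

Averaging over several permutations reduces the noise that a single random shuffle would introduce into the importance estimate.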
Any variables that are not used by the supplied model will have their
importance and interaction values set to zero. While the viviHeatmap and
viviNetwork visualization functions (seen below) are tailored for
displaying the results of vivi calculations, they can also work with any
square matrix that has identical row and column names. (Note that
symmetry is not required for viviHeatmap, while viviNetwork uses
interaction values from the lower-triangular part of the matrix only.)
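As a small illustration of the kind of input described above, the following base-R sketch builds a square matrix by hand, with importance values on the diagonal and interaction values off the diagonal. The variable names and values are made up for the example.

```r
vars <- c("a", "b", "c")
m <- matrix(0, nrow = 3, ncol = 3, dimnames = list(vars, vars))

diag(m) <- c(0.8, 0.5, 0.1)  # importance values on the diagonal
m["b", "a"] <- 0.30          # interaction values in the lower triangle
m["c", "a"] <- 0.05
m["c", "b"] <- 0.10

# Mirror the lower triangle so the matrix is symmetric.
m[upper.tri(m)] <- t(m)[upper.tri(m)]

# With vivid loaded, the matrix could then be passed to the plotting
# functions, for example:
# viviHeatmap(mat = m)
```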
This function works with many types of model fit and results in a matrix
which can be supplied to the plotting functions. The predict function
argument uses condvis2::CVpredict by default, which works for many fit
classes. To see a description of all function arguments use:
?vivid::vivi
set.seed(1701)
viviRf <- vivi(fit = rf,
               data = myData,
               response = "y",
               gridSize = 50,
               importanceType = "agnostic",
               nmax = 500,
               reorder = TRUE,
               predictFun = NULL,
               numPerm = 4,
               showVimpError = FALSE)
NOTE: If viewing this vignette from the vivid CRAN page, then some images may not format correctly. It is recommended to view this vignette via the following link: https://alaninglis.github.io/vivid/articles/vividVignette.html
The viviHeatmap function generates a heatmap that displays variable
importance and interactions, with importance values on the diagonal and
interaction values on the off-diagonal. The function only requires a
vivid matrix as input, which does not need to be symmetric.
Additionally, color palettes can be specified for both importance and
interactions via the impPal and intPal arguments. By default, we have
opted for single-hue, color-blind friendly sequential color palettes
developed by Zeileis et al. [5]. These palettes represent low and high
VIVI values with low and high luminance colors, respectively, which can
aid in highlighting pertinent values.
The impLims and intLims arguments determine the range of importance and
interaction values that will be assigned colors. If these arguments are
not provided, the default values will be calculated based on the minimum
and maximum VIVI values in the vivid matrix. If any importance or
interaction values fall outside of the specified limits, they will be
squished to the closest limit. For brevity, only the required vivid
matrix input is shown in the following code. To see a description of all
the function arguments, see ?vivid::viviHeatmap
viviHeatmap(mat = viviRf)
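The squishing of out-of-range values to the closest limit, as described for impLims and intLims, amounts to clamping. A minimal base-R sketch follows; the helper name squish and the example values are assumptions for illustration, not the code used by viviHeatmap.

```r
# Clamp values to the interval [limits[1], limits[2]] (hypothetical helper).
squish <- function(x, limits) pmin(pmax(x, limits[1]), limits[2])

intVals  <- c(0.02, 0.10, 0.45, 0.90)
squished <- squish(intVals, limits = c(0.05, 0.50))
# Values inside the limits are unchanged; 0.02 becomes 0.05 and 0.90 becomes 0.50.
```

Fixing the limits in this way is useful when comparing heatmaps from different fits, since the same value is then mapped to the same color in each plot.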
With viviNetwork, a network graph is produced to visualize both
importance and interactions. Similar to viviHeatmap, this function only
requires a vivid matrix as input and uses visual elements, such as size
and color, to depict the magnitude of importance and interaction values.
The graph displays each variable as a node, where its size and color
reflect its importance (larger and darker nodes indicate higher
importance). Pairwise interactions are displayed through connecting
edges, where thicker and darker edges indicate higher interaction
values.
To begin we show the network using default settings.
viviNetwork(mat = viviRf)
We can also filter out any interactions below a set value using the
intThreshold argument. This can be useful when the number of variables
included in the model is large, or just to highlight the strongest
interactions. By default, unconnected nodes are displayed; however, they
can be removed by setting the argument removeNode = TRUE.
viviNetwork(mat = viviRf, intThreshold = 0.12, removeNode = FALSE)
viviNetwork(mat = viviRf, intThreshold = 0.12, removeNode = TRUE)
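The effect of thresholding can be sketched in base R on a small hand-made matrix: interactions below the threshold are dropped, and any variable left without an edge is an unconnected node. The matrix values are invented, and treating values exactly equal to the threshold as retained is an assumption of this sketch rather than documented behaviour of viviNetwork.

```r
vars <- paste0("x", 1:4)
m <- matrix(c(1.00, 0.30, 0.05, 0.02,
              0.30, 0.80, 0.10, 0.20,
              0.05, 0.10, 0.50, 0.01,
              0.02, 0.20, 0.01, 0.20),
            nrow = 4, dimnames = list(vars, vars))

intThreshold <- 0.12

int <- m
diag(int) <- 0                  # ignore importance values on the diagonal
int[int < intThreshold] <- 0    # drop weak interactions

connected   <- vars[rowSums(int) > 0]   # nodes that keep at least one edge
unconnected <- setdiff(vars, connected) # nodes removeNode = TRUE would hide
```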
The network plot offers several options for arranging the nodes through
the layout argument. The default layout is a circle, but the argument
accepts any igraph layout function or a numeric matrix with two columns
and one row per node.
viviNetwork(mat = viviRf,
layout = cbind(c(1,1,1,1,2,2,2,2,2), c(1,2,4,5,1,2,3,4,5)))
Finally, to highlight relationships in the model fit, we can cluster
variables together in the network plot using the cluster argument. This
argument accepts either a vector of cluster memberships for nodes or an
igraph package clustering function. In the following example, we
manually select variables with VIVI values in the top 20%. This
selection allows us to focus only on the variables with the most impact
on the response. The variables that remain are \(x_1\) to \(x_5\). We
then perform a hierarchical clustering treating variable interactions as
similarities, with the goal of grouping together high-interaction
variables. Here we manually select the number of groups we want to show
via the cutree function (which cuts clustered data into a desired number
of groups). Finally we rearrange the layout using igraph. Here,
igraph::layout_as_star places the first variable (deemed most relevant
using the VIVI seriation process) at the center, which in Figure 5
emphasizes its key role as the most important predictor which also has
the strongest interactions.
set.seed(1701)
# clustered and filtered network for rf
intVals <- viviRf
diag(intVals) <- NA

# select VIVI values in top 20%
impTresh <- quantile(diag(viviRf),.8)
intThresh <- quantile(intVals,.8,na.rm=TRUE)
sv <- which(diag(viviRf) > impTresh |
              apply(intVals, 1, max, na.rm=TRUE) > intThresh)

h <- hclust(-as.dist(viviRf[sv,sv]), method="single")

viviNetwork(viviRf[sv,sv],
            cluster = cutree(h, k = 3), # specify number of groups
            layout = igraph::layout_as_star)
In Figure 5, after applying a hierarchical clustering, we can see that the variables with the strongest mutual interactions have been grouped together, namely \(x_1\), \(x_2\), and \(x_4\). The remaining variables are individually clustered.
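The clustering step can be illustrated on a toy interaction matrix using base R only. Note that this sketch converts similarities to dissimilarities with max(int) - int rather than by negation as in the code above; for single linkage both give the same merge order, since only the ordering of the dissimilarities matters. The matrix values are invented for the example.

```r
vars <- c("a", "b", "c", "d")
int <- matrix(0.1, nrow = 4, ncol = 4, dimnames = list(vars, vars))
int["a", "b"] <- int["b", "a"] <- 0.9   # strong interaction between a and b
int["c", "d"] <- int["d", "c"] <- 0.8   # strong interaction between c and d
diag(int) <- 0

# Treat interactions as similarities: a large interaction means a small distance.
d <- as.dist(max(int) - int)
h <- hclust(d, method = "single")

groups <- cutree(h, k = 2)   # named vector of cluster memberships
```

The resulting named membership vector is the kind of object the cluster argument of viviNetwork is described as accepting.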
The pdpVars function constructs a grid of univariate PDPs with ICE
curves for selected variables. We use ICE curves to assist in the
identification of linear or non-linear effects. The fit, data frame used
to train the model, and the name of the response variable are required
inputs.
In the example below, we select the first five variables from our
created vivid matrix to display and set the number of ICE curves to be
displayed to 100, via the nIce argument.
top5 <- colnames(viviRf)[1:5]
pdpVars(data = myData,
        fit = rf,
        response = 'y',
        vars = top5,
        nIce = 100)
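To show what the ICE curves and the partial dependence curve represent, here is a minimal base-R sketch. Each ICE curve takes one observation, varies a single predictor over a grid, and records the predictions; the partial dependence curve is the pointwise average of those curves. The toy data, the linear model, and the choice to average over only the sampled rows are assumptions for illustration, not the computation performed by pdpVars.

```r
set.seed(7)
n  <- 100
df <- data.frame(x1 = runif(n), x2 = runif(n))
df$y <- 3 * df$x1 + df$x2^2 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 + I(x2^2), data = df)

gridSize <- 10
nIce     <- 20
grid <- seq(min(df$x1), max(df$x1), length.out = gridSize)
rows <- sample(n, nIce)   # observations for which ICE curves are drawn

# One ICE curve per sampled row: vary x1 over the grid, hold the rest fixed.
ice <- t(vapply(rows, function(i) {
  tmp <- df[rep(i, gridSize), ]
  tmp$x1 <- grid
  unname(predict(fit, newdata = tmp))
}, numeric(gridSize)))

pd <- colMeans(ice)   # partial dependence: average of the ICE curves

# matplot(grid, t(ice), type = "l", col = "grey"); lines(grid, pd, lwd = 2)
```

For this additive linear fit every ICE curve is parallel to the others; curves that cross or fan out would suggest an interaction involving the plotted variable.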
By employing a matrix layout, the pdpPairs function generates a
generalized pairs partial dependence plot (GPDP) that shows univariate
partial dependence (with ICE curves) on the diagonal, bivariate partial
dependence above the diagonal, and a scatterplot of raw variable values
below the diagonal, where the colours of points and ICE curves are
assigned by the predicted \(\hat{y}\) value. As with the univariate PDP,
the fit, data frame used to train the model, and the name of the
response variable are required inputs. For a full description of all the
function arguments, see ?vivid::pdpPairs. In the following example, we
select the first five variables to display and set the number of shown
ICE curves to 100.
set.seed(1701)
pdpPairs(data = myData,
fit = rf,
response = "y",
nmax = 500,
gridSize = 10,
vars = c("x1", "x2", "x3", "x4", "x5"),
nIce = 100)
The pdpZen function uses a space-saving technique based on graph
Eulerians, introduced by Hierholzer and Wiener in 1873 [6], to create
partial dependence plots. We refer to these plots as zen-partial
dependence plots (ZPDP). These plots are based on zigzag expanded
navigation plots, also known as zenplots, which are available in the
zenplots package [7]. Zenplots were designed to showcase paired graphs
of high-dimensional data with a focus on the most significant 2D
displays. In our version, we display bivariate PDPs that emphasize
variables with the most significant interaction values in a compact
zigzag layout. This format is useful when dealing with a
high-dimensional predictor space.
To begin, we show a ZPDP using all the variables in the model.
set.seed(1701)
pdpZen(data = myData, fit = rf, response = "y", nmax = 500, gridSize = 10)
In Figure 8, we can see PDPs laid out in a zigzag structure, with the
most influential variable pairs displayed at the top and influence
generally decreasing as we move down. In Figure 9, below, we select a
subset of variables to display. In this case we select the first five
variables from the data. The argument zpath specifies the variables to
be plotted, defaulting to all dataset variables aside from the response.
set.seed(1701)
pdpZen(data = myData,
fit = rf,
response = "y",
nmax = 500,
gridSize = 10,
zpath = c("x1", "x2", "x3", "x4", "x5"))
We can also create a sequence or sequences of variable paths for use in
pdpZen via the zPath function. The zPath function takes four arguments:
viv, a matrix of interaction values; cutoff, a threshold below which
interaction values are excluded; method, a string indicating which
method to use to create the path; and connect, a logical value
indicating whether separate Eulerians should be connected.
You can choose between two methods when using the zPath function:
"greedy.weighted" and "strictly.weighted". The first method utilizes a
greedy Eulerian path algorithm for connected graphs. This method
traverses each edge at least once, beginning at the highest-weighted
edge, and moving on to the remaining edges while prioritizing the
highest-weighted edge. If the graph has an odd number of nodes, some
edges may be visited more than once, or additional edges may be visited.
The second method, "strictly.weighted", visits edges in strictly
decreasing order by weight (in this case, interaction values). If the
connect argument is set to TRUE, the sequences generated by the strictly
weighted method are combined to create a single path. In the code below,
we provide an example of creating zen-paths using the
"strictly.weighted" method, from the top 10% of interaction scores in
viviRf (i.e., the created vivid matrix).
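The edge ordering that underlies the strictly weighted method can be sketched in base R: keep the pairs whose interaction exceeds the cutoff and sort them by decreasing weight. This only illustrates the filtering and ordering of edges on an invented matrix; it does not reproduce how zPath assembles those edges into paths.

```r
vars <- paste0("x", 1:4)
int <- matrix(0, nrow = 4, ncol = 4, dimnames = list(vars, vars))
int["x2", "x1"] <- 0.9
int["x3", "x1"] <- 0.2
int["x4", "x2"] <- 0.6
int["x4", "x3"] <- 0.4

cutoff <- 0.3

# Pairs in the lower triangle with interaction above the cutoff.
idx <- which(lower.tri(int) & int > cutoff, arr.ind = TRUE)
edges <- data.frame(from   = vars[idx[, "col"]],
                    to     = vars[idx[, "row"]],
                    weight = int[idx])

# Visit edges in decreasing order of interaction strength.
edges <- edges[order(edges$weight, decreasing = TRUE), ]
```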
# set zpaths with different parameters
intVals <- viviRf
diag(intVals) <- NA
intThresh <- quantile(intVals, .90, na.rm=TRUE)
zpSw <- zPath(viv = viviRf, cutoff = intThresh, connect = FALSE, method = 'strictly.weighted')

set.seed(1701)
pdpZen(data = myData,
fit = rf,
response = "y",
nmax = 500,
gridSize = 10,
zpath = zpSw)
We supply an internal custom predict function called CVpredictfun to
both importance and interaction calculations. CVpredictfun is a wrapper
around CVpredict from the condvis2 package [8]. CVpredict accepts a
broad range of fit classes, thus streamlining the process of calculating
variable importance and interactions.
In situations where the fit class is not handled by CVpredict, supplying
a custom predict function to the vivi function by way of the predictFun
argument allows the agnostic VIVI values to be calculated. In the
following, we provide a small example of using such a fit with vivid by
using the xgboost package to create a gradient boosting machine (GBM).
To begin we build the model.
library("xgboost")
gbst <- xgboost(data = as.matrix(myData[,1:9]),
                label = as.matrix(myData[,10]),
                nrounds = 100,
                verbose = 0)
We then build the vivid matrix for the GBM fit using a custom predict
function, which must be of the form given in the code snippet.
# predict function for GBM
pFun <- function(fit, data, ...) predict(fit, as.matrix(data[,1:9]))

set.seed(1701)
viviGBst <- vivi(fit = gbst,
                 data = myData,
                 response = "y",
                 reorder = FALSE,
                 normalized = FALSE,
                 predictFun = pFun,
                 gridSize = 50,
                 nmax = 500)
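The same function(fit, data, ...) form can be written for other fit classes. As a further hedged sketch, the following base-R example wraps a logistic regression so that the function returns one predicted probability per row of the supplied data. The toy data and the name pFunGlm are assumptions for illustration; whether a given fit class actually needs a custom function depends on what CVpredict already supports.

```r
set.seed(3)
n  <- 120
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
df$y <- rbinom(n, 1, plogis(2 * df$x1))
fit <- glm(y ~ x1 + x2, data = df, family = binomial)

# Custom predict function: takes the fit and a data frame, and returns a
# plain numeric vector with one prediction per row (hypothetical helper).
pFunGlm <- function(fit, data, ...) {
  as.numeric(predict(fit, newdata = data, type = "response"))
}

p <- pFunGlm(fit, df)
```

Returning a plain numeric vector on the scale of interest (here, probabilities) means that importance and interaction values computed from it are on that scale as well.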
From this we can now create our visualisations. For brevity, we only show the heatmap.
viviHeatmap(mat = viviGBst)
In the following we show examples of how to select different (embedded)
importance metrics for use with the vivi function. To illustrate the
process we use random forest models fit using the randomForest and
ranger packages.
To begin we fit the randomForest model.
set.seed(1701)
rfEmbedded <- randomForest(y ~ ., data = myData, importance = TRUE)
Note that for a randomForest model, if the argument importance = TRUE is
set, then multiple importance metrics are returned. In this case, as we
have a regression random forest, the returned importance metrics are the
percent increase in mean squared error (%IncMSE) and the increase in
node purity (IncNodePurity). In order to choose a specific metric for
use with vivid, it is necessary to specify one of the importance metrics
returned by the random forest as the importanceType argument of the vivi
function. In the code below we select %IncMSE as the importance metric.
viviRfEmbedded <- vivi(fit = rfEmbedded,
                       data = myData,
                       response = "y",
                       importanceType = "%IncMSE")
For a ranger random forest model, the importance metric must be
specified when fitting the model. In the code below, we select impurity
as the importance metric.
rang <- ranger(y~., data = myData, importance = 'impurity')
Then when calling the vivi function, the importanceType argument is set
to the same selected importance metric.
viviRangEmbedded <- vivi(fit = rang,
                         data = myData,
                         response = "y",
                         importanceType = "impurity")
In this section, we briefly describe how to apply the above
visualisations to a classification example using the iris data set.
To begin we fit a ranger random forest model with “Species” as the
response and create the vivi matrix, setting the category for
classification to be “setosa” using the class argument.
set.seed(1701)
rfClassif <- ranger(Species~ ., data = iris, probability = T,
                    importance = "impurity")

set.seed(101)
viviClassif <- vivi(fit = rfClassif,
                    data = iris,
                    response = "Species",
                    gridSize = 10,
                    nmax = 50,
                    reorder = TRUE,
                    class = "setosa")
Next we plot the heatmap and network plot of the iris data.
viviHeatmap(mat = viviClassif)
viviNetwork(mat = viviClassif)
Because PDPs are evaluated on a grid, they can extrapolate into regions where there is no data. To address this issue we calculate a convex hull around the data and remove any grid points that fall outside the convex hull, as shown below.
set.seed(1701)
pdpPairs(data = iris,
fit = rfClassif,
response = "Species",
class = "setosa",
convexHull = T,
gridSize = 10,
nmax = 50)
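The convex hull idea can be sketched for two variables in base R: find the hull of the observed points with grDevices::chull, then keep only those grid points that lie inside it. The toy data, the helper name insideHull, and the decision to keep points lying exactly on the hull boundary are assumptions of this sketch, not the implementation behind the convexHull argument.

```r
# Observed data: three hull vertices plus two interior points.
x <- c(0, 4, 0, 1, 1)
y <- c(0, 0, 4, 1, 2)

hull <- grDevices::chull(x, y)   # indices of the hull vertices, in order
hx <- x[hull]
hy <- y[hull]

# A point is inside a convex polygon if it lies on the same side of every
# edge (hypothetical helper; boundary points count as inside).
insideHull <- function(px, py) {
  nx <- c(hx[-1], hx[1])   # next vertex along the hull
  ny <- c(hy[-1], hy[1])
  cross <- (nx - hx) * (py - hy) - (ny - hy) * (px - hx)
  all(cross <= 0) || all(cross >= 0)
}

# Evaluation grid over the bounding box of the data.
grid <- expand.grid(x = 0:4, y = 0:4)
keep <- mapply(insideHull, grid$x, grid$y)
gridInHull <- grid[keep, ]   # grid points where the PDP would be shown
```

Here the bounding box contains 25 grid points, but only those with x + y at most 4 fall within the triangle spanned by the data, so the remaining corner of the grid is discarded rather than filled with extrapolated predictions.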
[1] Inglis, Alan, Andrew Parnell, and Catherine B. Hurley. 2022. “Visualizing Variable Importance and Variable Interaction Effects in Machine Learning Models.” Journal of Computational and Graphical Statistics (3): 1–13.
[2] Friedman, Jerome H. 1991. “Multivariate Adaptive Regression Splines.” The Annals of Statistics 19 (1): 1–67.
[3] Friedman, J. H., and B. E. Popescu. 2008. “Predictive Learning via Rule Ensembles.” The Annals of Applied Statistics. JSTOR, 916–54.
[4] Fisher, A., C. Rudin, and F. Dominici. 2018. “All Models Are Wrong but Many Are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, Using Model Class Reliance.” arXiv.
[5] Zeileis, Achim, Jason C. Fisher, Kurt Hornik, Ross Ihaka, Claire D. McWhite, Paul Murrell, Reto Stauffer, and Claus O. Wilke. 2020. “Colorspace: A Toolbox for Manipulating and Assessing Colors and Palettes.” Journal of Statistical Software, Articles 96 (1): 1–49.
[6] Hierholzer, Carl, and Chr Wiener. 1873. “Über Die Möglichkeit, Einen Linienzug Ohne Wiederholung Und Ohne Unterbrechung Zu Umfahren.” Mathematische Annalen 6 (1): 30–32.
[7] Hofert, Marius, and Wayne Oldford. 2020. “Zigzag Expanded Navigation Plots in R: The R Package zenplots.” Journal of Statistical Software 95 (4): 1–44.
[8] Hurley, Catherine, Mark O’Connell, and Katarina Domijan. 2022. Condvis2: Interactive Conditional Visualization for Supervised and Unsupervised Models in Shiny.