IML: Machine Learning Model Interpretability And Feature Explanation with IML and H2O
Written by Brad Boehmke
Model interpretability is critical to businesses. If you want to use high performance models (GLM, RF, GBM, Deep Learning, H2O, Keras, xgboost, etc.), you need to learn how to explain them. With machine learning interpretability growing in importance, several R packages designed to provide this capability are gaining in popularity. We analyze the iml package in this article.
In recent blog posts we assessed LIME for model-agnostic local interpretability and DALEX for both local and global machine learning explanation plots. This post examines the iml package (short for Interpretable Machine Learning) to assess how well it provides machine learning interpretability and to help you determine whether it should become part of your preferred machine learning toolbox.
We again use the high performance machine learning library h2o, implementing three popular black-box modeling algorithms: GLM (generalized linear models), RF (random forest), and GBM (gradient boosted machines). For those who want a deep dive into model interpretability, the creator of the iml package, Christoph Molnar, has put together a free book: Interpretable Machine Learning. Check it out.
By Brad Boehmke, Director of Data Science at 84.51°
Advantages & disadvantages
The iml package is probably the most robust ML interpretability package available. It provides both global and local model-agnostic interpretation methods. Although the interaction functions are notably slow, the other functions are faster than or comparable to existing packages we use or have tested. I definitely recommend adding iml to your preferred ML toolkit. The following provides a quick list of some of its pros and cons:
Advantages
- ML model and package agnostic: can be used for any supervised ML model (many features are only relevant to regression and binary classification problems).
- Variable importance: uses a permutation-based approach for variable importance, which is model agnostic, and accepts any loss function to assess importance.
- Partial dependence plots: Fast PDP implementation and allows for ICE curves.
- H-statistic: one of only a few implementations to allow for assessing interactions.
- Local interpretation: provides both LIME and Shapley implementations.
- Plots: built with ggplot2, which allows for easy customization.
Disadvantages
- Does not allow for easy comparisons across models like DALEX.
- The H-statistic interaction functions do not scale well to wide data (many predictor variables).
- Only provides permutation-based variable importance scores (which become slow as the number of features increases).
- LIME implementation has less flexibility and fewer features than lime.
Replication Requirements
Libraries
I leverage the following packages:
Data
To demonstrate iml’s capabilities we’ll use the employee attrition data included in the rsample package. This is a binary classification problem (“Yes” vs. “No”), but the same process can be used for a regression problem.
I perform a few housekeeping tasks on the data prior to converting it to an h2o object and splitting.
NOTE: The surrogate tree function uses partykit::ctree, which requires all predictors to be numeric or factors. Consequently, we need to coerce any character predictors to factors (or ordinal encode them).
H2O Models
We will explore how to visualize a few of the more common machine learning algorithms implemented with h2o. For brevity I train default models and do not emphasize hyperparameter tuning. The following produces regularized logistic regression, random forest, and gradient boosting machine models.
All of the models provide AUCs ranging from 0.75 to 0.79. Although these models have distinct AUC scores, our objective is to understand how these models come to their conclusions in similar or different ways based on underlying logic and data structure.
IML procedures
In order to work with iml, we need to adapt our data a bit so that we have the following three components:
- Create a data frame with just the features (must be of class data.frame; it cannot be an H2OFrame or other class).
- Create a vector with the actual responses (must be numeric: 0/1 for binary classification problems).
- iml has internal support for some machine learning packages (e.g., mlr, caret, randomForest). However, to use iml with several of the more popular packages being used today (e.g., h2o, ranger, xgboost) we need to create a custom function that takes a data set (again, of class data.frame) and returns the predicted values as a vector.
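The three components can be sketched in a self-contained way with simulated data and a base R glm() as a hypothetical stand-in for the h2o models (the column names only mimic the attrition data; the values are made up). The wrapper pred_fun follows the (model, newdata) calling convention iml uses for custom prediction functions and returns a plain numeric vector of class-1 probabilities.

```r
# Hypothetical simulated data standing in for the attrition data
set.seed(123)
n <- 500
df <- data.frame(
  Age           = round(runif(n, 18, 60)),
  OverTime      = factor(sample(c("No", "Yes"), n, replace = TRUE)),
  MonthlyIncome = round(runif(n, 1000, 20000))
)
# Simulated attrition: more likely with overtime, less likely with age
eta <- -1 + 1.5 * (df$OverTime == "Yes") - 0.04 * (df$Age - 40)
df$Attrition <- rbinom(n, 1, plogis(eta))

# Component 1: a plain data.frame holding only the features
features <- df[, c("Age", "OverTime", "MonthlyIncome")]

# Component 2: a numeric 0/1 response vector
response <- df$Attrition

# Stand-in model (assumption: glm instead of an h2o model)
model <- glm(Attrition ~ Age + OverTime + MonthlyIncome, data = df, family = binomial)

# Component 3: takes (model, newdata) and returns predicted probabilities as an unnamed vector
pred_fun <- function(model, newdata) {
  as.numeric(predict(model, newdata = newdata, type = "response"))
}

head(pred_fun(model, features))
```

For an h2o model the body would instead convert newdata with as.h2o(), call h2o.predict(), convert the result back with as.data.frame(), and return the probability column for the positive class; that variant needs a running h2o cluster, so it is not shown here.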
Once we have these three components we can create a predictor object. Similar to DALEX and lime, the predictor object holds the model, the data, and the class labels to be applied to downstream functions. A unique characteristic of the iml package is that it uses R6 classes, which is rather rare. The main differences between R6 classes and the normal S3 and S4 classes we typically work with are:
- Methods belong to objects, not generics (we’ll see this in the next code chunk).
- Objects are mutable: the usual copy-on-modify semantics do not apply (we’ll see this later in this tutorial).
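Both properties can be seen with a toy class built with the R6 package (which iml relies on). The Counter class below is purely hypothetical and not part of iml: its method is reached with $, objects are created with the generator’s new() method, and assigning the object to a second name does not copy it.

```r
library(R6)

# Hypothetical toy class, only to illustrate R6 semantics
Counter <- R6Class("Counter",
  public = list(
    count = 0,
    # The method belongs to the object and is called with $
    add = function(by = 1) {
      self$count <- self$count + by
      invisible(self)
    }
  )
)

a <- Counter$new()  # construct through the generator's new() method
b <- a              # no copy is made: a and b point to the same object
b$add(5)            # modify through b ...
a$count             # ... and the change is visible through a
```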
These properties make R6 objects behave more like objects in programming languages such as Python. So to construct a new Predictor object, you call the new() method, which belongs to the R6 Predictor class generator, and you use $ to access new():
Global interpretation
iml provides a variety of ways to understand our models from a global perspective:
- Feature Importance
- Partial Dependence
- Measuring Interactions
- Surrogate Model
We’ll go through each.
1. Feature importance
We can measure how important each feature is for the predictions with FeatureImp. The feature importance measure works by calculating the increase in the model’s prediction error after permuting the feature. A feature is “important” if permuting its values increases the model error, because the model relied on the feature for the prediction. A feature is “unimportant” if permuting its values leaves the model error unchanged, because the model ignored the feature for the prediction. This model-agnostic approach is based on (Breiman, 2001; Fisher et al., 2018) and follows these steps:
For any given loss function do
1: compute loss function for original model
2: for variable i in {1,...,p} do
     | randomize values of variable i
     | apply given ML model
     | estimate loss function
     | compute feature importance (permuted loss / original loss)
   end
3: sort variables by descending feature importance
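The steps above can be sketched in a few lines of base R. This is not iml’s FeatureImp code: it assumes simulated data, a linear model and mean squared error as the loss, permutes each feature once, and reports the ratio of permuted to original loss.

```r
set.seed(42)
n <- 1000
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
# x1 drives the outcome strongly, x2 weakly, x3 not at all
dat$y <- 3 * dat$x1 + 0.5 * dat$x2 + rnorm(n)

fit <- lm(y ~ x1 + x2 + x3, data = dat)
mse <- function(actual, predicted) mean((actual - predicted)^2)

permutation_importance <- function(model, data, target, features, loss = mse) {
  # Step 1: loss of the model on the original data
  base_loss <- loss(data[[target]], predict(model, newdata = data))
  # Step 2: permute one feature at a time, re-predict, and compare losses
  imp <- vapply(features, function(f) {
    shuffled <- data
    shuffled[[f]] <- sample(shuffled[[f]])
    loss(data[[target]], predict(model, newdata = shuffled)) / base_loss
  }, numeric(1))
  # Step 3: sort by descending importance ratio
  sort(imp, decreasing = TRUE)
}

imp <- permutation_importance(fit, dat, "y", c("x1", "x2", "x3"))
imp
```

A ratio near 1 (as for x3 here) means permuting the feature barely changed the error; each permutation requires a full round of predictions, which is why the approach slows down as the number of predictors grows.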
We see that all three models find OverTime as the most influential variable; however, after that each model finds unique structure and signals within the data. Note: you can extract the results with imp.rf$results.
Permutation-based approaches can become slow as the number of predictors grows. Assessing variable importance for all 3 models in this example takes only 8 seconds. However, performing the same procedure on a data set with 80 predictors (AmesHousing::make_ames()) takes roughly 3 minutes. Although this is slower, it is comparable to other permutation-based implementations (e.g., DALEX, ranger).
The following lists some advantages and disadvantages of iml’s feature importance procedure.
Advantages:
- Model agnostic
- Simple interpretation that’s comparable across models
- Can apply any loss function (accepts custom loss functions as well)
- Plot output uses ggplot2; we can add to it or use our internal branding packages with it
Disadvantages:
- Permutation-based methods are slow
- Default plot contains all predictors; however, we can subset the $results data frame if desired
2. Partial dependence
The Partial class implements partial dependence plots (PDPs) and individual conditional expectation (ICE) curves. The procedure follows the traditional methodology documented in Friedman (2001) and Goldstein et al. (2014), where the ICE curve for a certain feature illustrates the predicted value for each observation when we force each observation to take on the unique values of that feature. The PDP curve represents the average prediction across all observations.
For a selected predictor (x)
1: determine grid space of j evenly spaced values across distribution of x
2: for value i in {1,...,j} of grid space do
     | set x to i for all observations
     | apply given ML model
     | estimate predicted value
     | if PDP: average predicted values across all observations
   end
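A base R sketch of these steps, assuming a simulated data set and a linear model with a quadratic term as the model being explained (neither comes from the article). Each column of the ICE matrix holds predictions with the feature fixed at one grid value, and the PDP is the column mean. The last line centers each ICE curve on its value at the first grid point, which is the idea behind the centered ICE curves mentioned in the next paragraph; iml’s own centering option may differ in detail.

```r
set.seed(1)
n <- 300
dat <- data.frame(x1 = runif(n, -2, 2), x2 = runif(n, -2, 2))
dat$y <- dat$x1^2 + dat$x1 * dat$x2 + rnorm(n, sd = 0.1)
fit <- lm(y ~ I(x1^2) + x1 * x2, data = dat)

ice_pdp <- function(model, data, feature, grid_size = 20) {
  # Step 1: evenly spaced grid across the observed range of the feature
  grid <- seq(min(data[[feature]]), max(data[[feature]]), length.out = grid_size)
  # Step 2: force every observation to each grid value and predict
  # (rows = observations, columns = grid values, so each row is one ICE curve)
  ice <- sapply(grid, function(v) {
    tmp <- data
    tmp[[feature]] <- v
    predict(model, newdata = tmp)
  })
  # PDP: average the ICE curves at each grid value
  list(grid = grid, ice = ice, pdp = colMeans(ice))
}

res <- ice_pdp(fit, dat, "x1")
# Centered ICE: subtract each curve's value at the first grid point
centered <- res$ice - res$ice[, 1]
```

Lowering grid_size reduces the number of prediction passes linearly, which is the same lever the grid.size argument gives you in iml.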
The following produces “ICE boxplots” and a PDP line (connecting the averages of all observations for each response class) for the most important variable in all three models (OverTime). All three models show a sizable increase in the probability of employees attriting when working overtime. However, you will notice the random forest model experiences less of an increase in probability compared to the other two models.
For continuous predictors you can reduce the grid space to make computation more efficient, and you can center the ICE curves. Note: to produce the centered ICE curves (right plot) you use ice$center and provide it the value to center on. This modifies the existing object in place (recall this is a unique characteristic of R6: objects are mutable). The following compares the marginal impact of age on the probability of attriting. The regularized regression model shows a monotonic decrease in the probability (its log-odds are linear in age), while the two tree-based approaches capture the non-linear, non-monotonic relationship.
Similar to pdp, you can also compute and plot 2-way interactions. Here we assess how the interaction of MonthlyIncome and OverTime influences the predicted probability of attrition for all three models.
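The same idea extends to two features: fix both at every combination of grid values and average the predictions. The sketch below assumes two numeric features, simulated data and a linear model with an interaction term (all hypothetical); a categorical feature such as OverTime would use its levels as the grid instead of an evenly spaced sequence.

```r
set.seed(1)
n <- 300
dat <- data.frame(x1 = runif(n, -2, 2), x2 = runif(n, -2, 2))
dat$y <- dat$x1 * dat$x2 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 * x2, data = dat)

pdp_2way <- function(model, data, f1, f2, grid_size = 10) {
  g1 <- seq(min(data[[f1]]), max(data[[f1]]), length.out = grid_size)
  g2 <- seq(min(data[[f2]]), max(data[[f2]]), length.out = grid_size)
  grid <- expand.grid(v1 = g1, v2 = g2)
  # For each (v1, v2) pair, fix both features for all rows and average the predictions
  grid$yhat <- mapply(function(v1, v2) {
    tmp <- data
    tmp[[f1]] <- v1
    tmp[[f2]] <- v2
    mean(predict(model, newdata = tmp))
  }, grid$v1, grid$v2)
  names(grid)[1:2] <- c(f1, f2)
  grid
}

surface <- pdp_2way(fit, dat, "x1", "x2")
head(surface)
```

The resulting long-format grid (one row per pair of values) is the shape a heatmap layer such as ggplot2’s geom_tile() expects.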
The following lists some advantages and disadvantages of iml’s PDP and ICE procedures.
Advantages:
- Provides PDPs & ICE curves (unlike DALEX)
- Allows you to center ICE curves
- Computationally efficient
- grid.size allows you to increase/reduce the grid space of xi values
- Rug option illustrates distribution of all xi values
- Provides convenient plot outputs for categorical predictors
Disadvantages:
- Only provides heatmaps for 2-way interaction plots
- Does not allow for easy comparison across models like DALEX
3. Measuring Interactions
A wonderful feature provided by iml is the ability to measure how strongly features interact with each other. To measure interaction, iml uses the H-statistic proposed by Friedman and Popescu (2008). The H-statistic measures how much of the variation of the predicted outcome depends on the interaction of the features. There are two approaches to measuring this. The first measures whether a feature (xi) interacts with any other feature. The algorithm performs the following steps:
(f(x), pd(x) and pd(!x) are mean-centered and evaluated at every observation)
1: for variable i in {1,...,p} do
     | f(x) = estimate predicted values with original model
     | pd(x) = partial dependence of variable i
     | pd(!x) = partial dependence of all features excluding i
     | upper = sum((f(x) - pd(x) - pd(!x))^2)
     | lower = sum(f(x)^2)
     | rho = upper / lower
   end
2: sort variables by descending rho (interaction strength)
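A base R sketch of this one-way statistic, assuming simulated data in which x1 and x2 interact while x3 enters additively, with a linear model standing in for the black box. It returns the squared statistic (upper / lower, as in the steps above); Friedman and Popescu’s H is its square root. Each partial dependence is evaluated at every observation, so the cost grows with the square of the number of rows, one reason these functions are slow.

```r
set.seed(7)
n <- 100
dat <- data.frame(x1 = runif(n, -1, 1), x2 = runif(n, -1, 1), x3 = runif(n, -1, 1))
# x1 and x2 interact; x3 enters additively
dat$y <- dat$x1 + dat$x2 + 2 * dat$x1 * dat$x2 + dat$x3 + rnorm(n, sd = 0.05)
fit <- lm(y ~ x1 * x2 + x3, data = dat)
f <- function(newdata) predict(fit, newdata = newdata)
features <- dat[, c("x1", "x2", "x3")]

# Mean-centered partial dependence of `vars`, evaluated at each observation's values
pd_at_obs <- function(f, data, vars) {
  n <- nrow(data)
  pd <- vapply(seq_len(n), function(i) {
    tmp <- data
    tmp[vars] <- data[rep(i, n), vars, drop = FALSE]  # fix vars at observation i
    mean(f(tmp))                                      # average over the rest
  }, numeric(1))
  pd - mean(pd)
}

# Share of prediction variance due to interactions of `feature` with anything else
h_overall <- function(f, data, feature) {
  others   <- setdiff(names(data), feature)
  fx       <- f(data) - mean(f(data))
  pd_j     <- pd_at_obs(f, data, feature)
  pd_not_j <- pd_at_obs(f, data, others)
  sum((fx - pd_j - pd_not_j)^2) / sum(fx^2)
}

h <- sapply(names(features), function(v) h_overall(f, features, v))
round(h, 3)
```

Because x3 is purely additive in this model, its statistic is zero up to floating-point error, while x1 and x2 both register the interaction.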
The interaction strength (rho) will be between 0, when there is no interaction at all, and 1, if all of the variation of the predicted outcome depends on a given interaction. All three models capture different interaction structures, although some commonalities exist across models (e.g., OverTime, Age, JobRole). The interaction effects are stronger in the tree-based models than in the GLM model, with the GBM model having the strongest interaction effect of 0.4.
Considering OverTime exhibits the strongest interaction signal, the next question is which other variable it interacts with most strongly. The second interaction approach measures the 2-way interaction strength of features xi and xj and performs the following steps:
(pd(ij), pd(i) and pd(j) are mean-centered and evaluated at every observation)
1: i = a selected variable of interest
2: for remaining variables j in {1,...,p} do
     | pd(ij) = interaction partial dependence of variables i and j
     | pd(i) = partial dependence of variable i
     | pd(j) = partial dependence of variable j
     | upper = sum((pd(ij) - pd(i) - pd(j))^2)
     | lower = sum(pd(ij)^2)
     | rho = upper / lower
   end
3: sort interaction relationships by descending rho (interaction strength)
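A matching base R sketch of the two-way statistic, under the same assumptions as before (simulated data where only x1 and x2 interact, a linear model as the stand-in black box). It pairs one feature of interest with every other feature and returns squared statistics.

```r
set.seed(7)
n <- 100
dat <- data.frame(x1 = runif(n, -1, 1), x2 = runif(n, -1, 1), x3 = runif(n, -1, 1))
dat$y <- dat$x1 + dat$x2 + 2 * dat$x1 * dat$x2 + dat$x3 + rnorm(n, sd = 0.05)
fit <- lm(y ~ x1 * x2 + x3, data = dat)
f <- function(newdata) predict(fit, newdata = newdata)
features <- dat[, c("x1", "x2", "x3")]

# Mean-centered partial dependence of `vars`, evaluated at each observation's values
pd_at_obs <- function(f, data, vars) {
  n <- nrow(data)
  pd <- vapply(seq_len(n), function(i) {
    tmp <- data
    tmp[vars] <- data[rep(i, n), vars, drop = FALSE]
    mean(f(tmp))
  }, numeric(1))
  pd - mean(pd)
}

# Share of the joint partial dependence of (j, k) not explained by their separate effects
h_twoway <- function(f, data, j, k) {
  pd_jk <- pd_at_obs(f, data, c(j, k))
  pd_j  <- pd_at_obs(f, data, j)
  pd_k  <- pd_at_obs(f, data, k)
  sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2)
}

# Two-way interaction strength between x1 and every other feature
h2 <- sapply(c("x2", "x3"), function(k) h_twoway(f, features, "x1", k))
round(h2, 3)
```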
The following measures the two-way interactions of all variables with the OverTime variable. The two tree-based models show MonthlyIncome having the strongest interaction (although it is a weak interaction since rho < 0.13). Identifying these interactions can be useful to understand which variables create co-dependencies in our models’ behavior. It also helps us identify interactions to visualize with PDPs (which is why I showed the example of the OverTime and MonthlyIncome interaction PDP earlier).
The H-statistic is not widely implemented, so having this feature in iml is beneficial. However, it’s important to note that as your feature set grows, the H-statistic becomes computationally slow. For this data set, measuring the interactions across all three models took only 45 seconds, and 68 seconds for the two-way interactions. However, for a wider data set such as AmesHousing::make_ames(), where there are 80 predictors, this can take up to an hour to compute.
4. Surrogate model
Another way to make the models more interpretable is to replace the “black box” model with a simpler model (aka “white box” model) such as a decision tree. This is known as a surrogate model, in which we
1. apply the original model and get predictions
2. choose an interpretable "white box" model (linear model, decision tree)
3. train the interpretable model on the original dataset and its predictions
4. measure how well the surrogate model replicates the predictions of the black box model
5. interpret / visualize the surrogate model
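A minimal sketch of these five steps, with two stated substitutions: the “black box” is a hypothetical closed-form function on simulated data, and the surrogate is a CART tree from rpart (a recommended package that ships with R) rather than the partykit tree iml uses. R^2 here is the share of the black box’s prediction variance that the tree reproduces.

```r
library(rpart)

set.seed(99)
n <- 1000
dat <- data.frame(age = runif(n, 18, 60), income = runif(n, 1, 20), overtime = rbinom(n, 1, 0.3))
# Hypothetical black box: any function mapping a data.frame to predicted probabilities
black_box <- function(d) {
  plogis(-1 + 1.5 * d$overtime - 0.05 * (d$age - 40) - 0.1 * (d$income - 10))
}

# 1. Apply the original model and get predictions
dat$pred <- black_box(dat)
# 2-3. Fit an interpretable model (a depth-3 regression tree) to those predictions
surrogate <- rpart(pred ~ age + income + overtime, data = dat, method = "anova",
                   control = rpart.control(maxdepth = 3))
# 4. Measure how well the tree replicates the black box predictions
fitted_sur <- predict(surrogate, newdata = dat)
r2 <- 1 - sum((dat$pred - fitted_sur)^2) / sum((dat$pred - mean(dat$pred))^2)
r2
# 5. Interpret: the printed splits and terminal-node means
print(surrogate)
```

If r2 is low, the tree’s splits describe the tree, not the black box, which is the caution raised below.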
iml provides a simple decision tree surrogate approach, which leverages partykit::ctree. In this example we train a conditional inference tree with a max depth of 3 on our GBM model. You can see that the white box model does not do a good job of explaining the black box predictions (R^2 = 0.438).
The plot illustrates the distribution of the probability of attrition for each terminal node. We see that an employee with JobLevel > 1 and DistanceFromHome <= 12 has a very low probability of attriting.
When trying to explain a complicated machine learning model to decision makers, surrogate models can help simplify the process. However, it’s important to only use surrogate models for simplified explanations when they are actually good representatives of the black box model (in this example it is not).
Local interpretation
In addition to providing global explanations of ML models, iml also provides two newer but well-accepted methods for local interpretation. Local interpretation techniques explain why an individual prediction was made for a given observation.
To illustrate, let’s get two observations. The first is the observation for which our random forest model produces the highest probability of attrition (observation 154 has a 0.666 probability of attrition) and the second is the observation with the lowest probability (observation 28 has a 0 probability of attrition).
1. LIME
iml implements its own version of local interpretable model-agnostic explanations (Ribeiro et al., 2016). Although it does not use the lime package directly, it does implement the same procedures (see the lime tutorial).
A few notable items about iml’s implementation (see the referenced tutorial above for these details within lime):
- like lime, can change the distance metric (default is Gower but accepts all distance functions provided by ?dist),
- like lime, can change the kernel (neighborhood size),
- like lime, computationally efficient: takes about 5 seconds to compute,
- like lime, can be applied to multinomial responses,
- like lime, uses the glmnet package to fit the local model; however…
- unlike lime, only implements a ridge model (lime allows ridge, lasso, and more),
- unlike lime, can only do one observation at a time (lime can do multiple),
- unlike lime, does not provide fit metrics such as R^2 for the local model.
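A stripped-down base R sketch of the LIME idea, not iml’s LocalModel. Assumptions: the observed rows serve as the perturbed samples, proximity is a Gaussian kernel on the Euclidean distance of standardized features (iml defaults to Gower distance), the black box is a hypothetical function, and the local model is a weighted ridge regression solved in closed form rather than with glmnet.

```r
set.seed(2024)
n <- 500
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
# Hypothetical black box: non-linear in x1, decreasing in x2, ignores x3
black_box <- function(d) plogis(2 * d$x1^2 - d$x2 - 1)

lime_sketch <- function(f, data, x_interest,
                        kernel_width = 0.75 * sqrt(ncol(data)), lambda = 0.01) {
  X <- as.matrix(data)
  x0 <- unlist(x_interest[colnames(X)])
  yhat <- f(data)                     # black box predictions for the samples
  # Proximity weights: Gaussian kernel on distance of standardized features
  s <- apply(X, 2, sd)
  d <- sqrt(rowSums(sweep(sweep(X, 2, x0), 2, s, "/")^2))
  w <- exp(-d^2 / kernel_width^2)
  # Weighted ridge regression in closed form; the intercept is not penalized
  Z <- cbind(1, X)
  P <- diag(c(0, rep(lambda, ncol(X))))
  beta <- solve(crossprod(Z, Z * w) + P, crossprod(Z, w * yhat))
  setNames(drop(beta), c("(Intercept)", colnames(X)))
}

x_interest <- data.frame(x1 = 1, x2 = 0, x3 = 0)
coefs <- lime_sketch(black_box, dat, x_interest)
round(coefs, 3)
```

The local coefficients play the role of the feature effects that LIME plots: here x1 pushes the prediction up near this point, x2 pushes it down, and x3 contributes almost nothing.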
The following fits a local model for the observation with the highest probability of attrition. In this example I look for the 10 variables in each model that are most influential in this observation’s predicted value (k = 10). The results show that the Age of the employee reduces the probability of attrition within all three models. Moreover, all three models show that because this employee works OverTime, the probability of attrition increases sizably. However, the tree-based models also identify the MaritalStatus and JobRole of this employee as contributing to his/her increased probability of attrition.
Next, I reapply the same model to low_prob_ob. Here we see Age, JobLevel, and OverTime all having a sizable influence on this employee’s very low predicted probability of attrition (zero).
Although LocalModel does not provide fit metrics (e.g., R^2) for our local model, we can compare the local model’s predicted probability to the global (full) model’s predicted probability.
For the high probability employee, the local model predicts only a 0.34 probability of attrition (versus the random forest’s 0.666), whereas for the low probability employee the local model predicts a more accurate 0.12 probability of attrition (versus 0). This can help you gauge the trustworthiness of the local model.
2. Shapley values
An alternative for explaining individual predictions is a method from coalitional game theory that produces what are called Shapley values (Lundberg & Lee, 2016). The idea behind Shapley values is to assess every combination of predictors to determine each predictor’s impact. Focusing on feature xj, the approach computes the prediction for every combination of features not including xj and then measures how adding xj to each combination changes the prediction. Unfortunately, computing exact Shapley values is very computationally expensive. Consequently, iml implements an approximate Shapley estimation algorithm that follows these steps:
ob = single observation of interest
1: for variables j in {1,...,p} do
     | for m in {1,...,M} do   (M = number of samples, see sample.size)
     |   z = random observation from data set
     |   o = random ordering of the features
     |   x(+j) = ob's values for j and the features before j in o; z's values for the rest
     |   x(-j) = x(+j), but with feature j taken from z
     |   diff(m) = f(x(+j)) - f(x(-j))
     | end
     | phi(j) = mean(diff)
   end
2: sort phi in decreasing order
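A base R sketch of this kind of permutation-sampling approximation (after Štrumbelj and Kononenko, 2014), assuming a hypothetical additive black box on simulated data. For an additive model with independent features, the exact Shapley value of feature j is its coefficient times the feature’s distance from the data mean, which makes the estimates easy to sanity-check; iml’s Shapley class differs in implementation details.

```r
set.seed(11)
n <- 1000
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
# Hypothetical additive black box, so the true Shapley values are known
black_box <- function(d) 2 * d$x1 - 1 * d$x2 + 0.5 * d$x3

shapley_mc <- function(f, data, x_interest, n_samples = 500) {
  feats <- names(data)
  phi <- vapply(feats, function(j) {
    diffs <- vapply(seq_len(n_samples), function(m) {
      z <- data[sample.int(nrow(data), 1), , drop = FALSE]  # random observation
      ord <- sample(feats)                                   # random feature ordering
      take <- ord[seq_len(match(j, ord))]                    # j and the features before it
      # x_plus: features in `take` come from x_interest, the rest from z
      x_plus <- z
      x_plus[take] <- x_interest[take]
      # x_minus: identical except that feature j comes from z
      x_minus <- x_plus
      x_minus[j] <- z[j]
      f(x_plus) - f(x_minus)
    }, numeric(1))
    mean(diffs)
  }, numeric(1))
  sort(phi, decreasing = TRUE)
}

x_interest <- data.frame(x1 = 1, x2 = -1, x3 = 0)
phi <- shapley_mc(black_box, dat, x_interest)
round(phi, 2)
```

Increasing n_samples tightens the estimates at a proportional cost in predictions, the same trade-off the sample.size argument controls.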
The Shapley value ($\phi$) represents the contribution of each respective variable towards the predicted value compared to the average prediction for the data set.
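For reference, the standard Shapley value definition is shown below in our own notation, where $\hat{f}_S(x)$ denotes the expected prediction when only the features in $S$ are fixed at the observation’s values and $p$ is the number of features. The second identity (efficiency) is why the contributions add up to the gap between an observation’s prediction and the average prediction:

$$
\phi_j = \sum_{S \subseteq \{1,\dots,p\} \setminus \{j\}} \frac{|S|!\,(p - |S| - 1)!}{p!} \left[ \hat{f}_{S \cup \{j\}}(x) - \hat{f}_S(x) \right],
\qquad
\sum_{j=1}^{p} \phi_j = \hat{f}(x) - \mathbb{E}\big[\hat{f}(X)\big]
$$

The sum runs over $2^{p-1}$ subsets for each feature, which is what makes exact computation expensive and motivates the sampling approximation above.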
We use Shapley$new to create a new Shapley object. For this data set it takes about 9 seconds to compute. The time to compute is largely driven by the number of predictors, but you can also control the sample size drawn (see the sample.size argument) to help reduce compute time. If you look at the results, you will see that the predicted value of 0.667 is 0.496 larger than the average prediction of 0.17. The plot displays the contribution each predictor made to this difference of 0.496.
We can compare the Shapley values across each model to see if common themes appear. Again, OverTime is a common theme across all three models. We also see MonthlyIncome being influential for the tree-based methods, and there are other commonalities for the mildly influential variables across all three models (e.g., StockOptionLevel, JobLevel, Age, MaritalStatus).
Similarly, we can apply this to the low probability employee. Some common themes pop out for this employee as well. It appears that the age, total number of working years, and the senior position (JobLevel, JobRole) play a large part in the low predicted probability of attrition for this employee.
Shapley values are considered more robust than the results you will get from LIME. However, similar to the different ways you can compute variable importance, although you will see differences between the two methods, you will often see the same variables identified as highly influential in both approaches. Consequently, we should use these approaches to help indicate influential variables but not to definitively label a variable as the most influential.