
Computes global explanations for tree-based models by decomposing the regression or classification function into a sum of main components and interaction components of arbitrary order. Also calculates SHAP values and q-interaction SHAP for all values of q for tree-based models such as xgboost.

Usage

glex(
  object,
  x,
  max_interaction = NULL,
  features = NULL,
  probFunction = NULL,
  ...
)

# S3 method for class 'rpf'
glex(object, x, max_interaction = NULL, features = NULL, ...)

# S3 method for class 'xgb.Booster'
glex(
  object,
  x,
  max_interaction = NULL,
  features = NULL,
  probFunction = NULL,
  ...
)

# S3 method for class 'ranger'
glex(
  object,
  x,
  max_interaction = NULL,
  features = NULL,
  probFunction = NULL,
  ...
)

Arguments

object

Model to be explained, of class xgb.Booster, rpf, or ranger.

x

Data to be explained.

max_interaction

(integer(1): NULL)
Maximum interaction size to consider. Defaults to using all possible interactions available in the model.
For xgboost, this defaults to the max_depth parameter of the model fit.
If not set in xgboost, the default value of 6 is assumed.
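
As a rough guide to how max_interaction affects the size of the result, the number of possible components up to a given order can be counted directly. The helper n_components() below is hypothetical and not part of glex. It gives an upper bound only, since a fitted model may contain fewer terms.

```r
# Hypothetical helper (not part of glex): upper bound on the number of
# main and interaction components for p features up to a given order.
n_components <- function(p, max_interaction) {
  orders <- seq_len(min(p, max_interaction))
  sum(choose(p, orders))
}

n_components(10, 1)  # main effects only
n_components(10, 2)  # main effects plus all pairwise interactions
```

With 10 features and max_interaction = 2 this gives 10 + 45 = 55, which matches the 55 columns of m in the Random Planted Forest example on this page.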

features

Vector of column names in x to calculate components for. Default is NULL, i.e. all features are used.

probFunction

Either "path-dependent" to use the older path-dependent weighting of leaves, or a user-specified probability function with the signature function(coords, lb, ub). If NULL (the default) or "empirical", the empirical marginal probabilities are used.

...

Further arguments passed to methods.

Value

Decomposition of the regression or classification function. A list with elements:

  • shap: SHAP values (xgboost method only).

  • m: Functional decomposition into all main and interaction components in the model, up to the degree specified by max_interaction. The variable names correspond to the original variable names, with : separating interaction terms as one would specify in a formula interface.

  • intercept: Intercept term, the expected value of the prediction.

  • x: The feature columns of the data passed as x.
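
The structure of the returned list can be illustrated without fitting a model. The sketch below uses a made-up stand-in object (toy, with invented values) that mimics only the m and intercept elements. It shows how the : separator distinguishes interaction terms from main effects, and how the components sum back to a prediction.

```r
# Toy stand-in for a glex result (values are made up for illustration);
# a real object would come from glex(model, x).
toy <- list(
  m = data.frame(
    cyl = c(0.8, -0.6),
    disp = c(0.5, -1.0),
    `cyl:disp` = c(0.1, 0.2),
    check.names = FALSE  # keep the ":" in the interaction column name
  ),
  intercept = 20.1
)

# Split columns into main effects and interaction terms by the ":" separator.
is_interaction <- grepl(":", names(toy$m), fixed = TRUE)
main_terms <- names(toy$m)[!is_interaction]
interaction_terms <- names(toy$m)[is_interaction]

# Adding the intercept to the row-wise sum of all components gives the
# reconstructed prediction for each observation.
reconstructed <- toy$intercept + rowSums(toy$m)
```

For a real result, comparing the intercept plus rowSums() of m against the model's own predictions is a quick sanity check. Whether the two agree exactly is likely to depend on max_interaction and the model type.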

Details

For parallel execution with xgboost or ranger models, register a parallel backend before calling glex(), e.g. with doParallel::registerDoParallel().
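
A minimal sketch of registering and releasing a backend, assuming the doParallel and foreach packages are installed. The worker count of 2 is an arbitrary choice for illustration, and the glex() call itself is left as a placeholder.

```r
# Sketch: register a parallel backend before calling glex(), then clean up.
# Assumes the doParallel and foreach packages are installed.
if (requireNamespace("doParallel", quietly = TRUE) &&
    requireNamespace("foreach", quietly = TRUE)) {
  doParallel::registerDoParallel(cores = 2)  # 2 workers is an arbitrary choice
  n_workers <- foreach::getDoParWorkers()    # confirm what was registered

  # ... call glex(model, x) here ...

  doParallel::stopImplicitCluster()          # release workers when done
}
```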

Examples


# Random Planted Forest -----
if (requireNamespace("randomPlantedForest", quietly = TRUE)) {
library(randomPlantedForest)

rp <- rpf(mpg ~ ., data = mtcars[1:26, ], max_interaction = 2)

glex_rpf <- glex(rp, mtcars[27:32, ])
str(glex_rpf, list.len = 5)
}
#> List of 3
#>  $ m        :Classes ‘data.table’ and 'data.frame':	6 obs. of  55 variables:
#>   ..$ cyl      : num [1:6] 0.787 0.787 -0.581 -0.184 -0.581 ...
#>   ..$ disp     : num [1:6] 0.547 1.803 -0.959 0.343 -0.62 ...
#>   ..$ hp       : num [1:6] 2.839 0.181 -2.506 -0.618 -2.506 ...
#>   ..$ drat     : num [1:6] 2.94 -0.503 2.94 -0.423 -0.423 ...
#>   ..$ wt       : num [1:6] 1.621 1.569 0.465 0.814 -1.043 ...
#>   .. [list output truncated]
#>   ..- attr(*, ".internal.selfref")=<externalptr> 
#>  $ intercept: num 20.1
#>  $ x        :Classes ‘data.table’ and 'data.frame':	6 obs. of  10 variables:
#>   ..$ cyl : num [1:6] 4 4 8 6 8 4
#>   ..$ disp: num [1:6] 120.3 95.1 351 145 301 ...
#>   ..$ hp  : num [1:6] 91 113 264 175 335 109
#>   ..$ drat: num [1:6] 4.43 3.77 4.22 3.62 3.54 4.11
#>   ..$ wt  : num [1:6] 2.14 1.51 3.17 2.77 3.57 ...
#>   .. [list output truncated]
#>   ..- attr(*, ".internal.selfref")=<externalptr> 
#>  - attr(*, "class")= chr [1:3] "glex" "rpf_components" "list"
# xgboost -----
if (requireNamespace("xgboost", quietly = TRUE)) {
library(xgboost)
x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg
xg <- xgboost(data = x[1:26, ], label = y[1:26],
              params = list(max_depth = 4, eta = .1),
              nrounds = 10, verbose = 0)
glex(xg, x[27:32, ])

if (FALSE) { # \dontrun{
# Parallel execution
doParallel::registerDoParallel()
glex(xg, x[27:32, ])
} # }
}
# ranger -----
if (requireNamespace("ranger", quietly = TRUE)) {
library(ranger)
x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg
rf <- ranger(x = x[1:26, ], y = y[1:26],
             num.trees = 5, max.depth = 3,
             node.stats = TRUE)
glex(rf, x[27:32, ])

if (FALSE) { # \dontrun{
# Parallel execution
doParallel::registerDoParallel()
glex(rf, x[27:32, ])
} # }
}