randomPlantedForest implements “Random Planted Forest”, a directly interpretable tree ensemble (arxiv).

Installation

You can install the development version of randomPlantedForest from GitHub with

# install.packages("remotes")
remotes::install_github("PlantedML/randomPlantedForest")

or from r-universe with

install.packages("randomPlantedForest", repos = "https://plantedml.r-universe.dev")

Example

Model fitting uses a familiar interface:

library(randomPlantedForest)

mtcars$cyl <- factor(mtcars$cyl)
rpfit <- rpf(mpg ~ cyl + wt + hp, data = mtcars, ntrees = 25, max_interaction = 2)
rpfit
#> -- Regression Random Planted Forest --
#> 
#> Formula: mpg ~ cyl + wt + hp 
#> Fit using 3 predictors and 2-degree interactions.
#> Forest is _not_ purified!
#> 
#> Called with parameters:
#> 
#>             loss: L2
#>           ntrees: 25
#>  max_interaction: 2
#>           splits: 30
#>        split_try: 10
#>            t_try: 0.4
#>            delta: 0
#>          epsilon: 0.1
#>    deterministic: FALSE
#>         nthreads: 1
#>           purify: FALSE
#>               cv: FALSE

predict(rpfit, new_data = mtcars) |>
  cbind(mpg = mtcars$mpg) |>
  head()
#>      .pred  mpg
#> 1 20.81459 21.0
#> 2 20.72354 21.0
#> 3 26.04526 22.8
#> 4 21.26845 21.4
#> 5 18.45921 18.7
#> 6 19.54406 18.1

Prediction components can be accessed via predict_components, which returns the intercept, the main effects, and the interaction effects up to a specified degree. The returned object also contains the original data as x, which is required for visualization. Alternatively, the glex package can be used: glex(rpfit) yields the same result.

components <- predict_components(rpfit, new_data = mtcars) 

str(components)
#> List of 3
#>  $ m        :Classes 'data.table' and 'data.frame':  32 obs. of  6 variables:
#>   ..$ cyl   : num [1:32] 0.445 0.445 0.863 0.445 -1.274 ...
#>   ..$ wt    : num [1:32] -0.0615 -0.1421 2.3182 -0.0155 -0.3116 ...
#>   ..$ hp    : num [1:32] 0.162 0.162 2.021 0.162 -0.941 ...
#>   ..$ cyl:wt: num [1:32] 0.00389 0.00389 0.69586 0.17156 0.4615 ...
#>   ..$ cyl:hp: num [1:32] 0.1453 0.1453 -0.0511 0.1453 0.1179 ...
#>   ..$ hp:wt : num [1:32] -0.1264 -0.1367 -0.0487 0.1138 0.1596 ...
#>   ..- attr(*, ".internal.selfref")=<externalptr> 
#>  $ intercept: num 20.2
#>  $ x        :Classes 'data.table' and 'data.frame':  32 obs. of  3 variables:
#>   ..$ cyl: Factor w/ 3 levels "4","6","8": 2 2 1 2 3 2 3 1 1 2 ...
#>   ..$ wt : num [1:32] 2.62 2.88 2.32 3.21 3.44 ...
#>   ..$ hp : num [1:32] 110 110 93 110 175 105 245 62 95 123 ...
#>   ..- attr(*, ".internal.selfref")=<externalptr> 
#>  - attr(*, "class")= chr [1:3] "glex" "rpf_components" "list"
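
As a quick sanity check on the decomposition, the intercept plus the row-wise sum of the columns in m should reproduce the model's predictions. The following is a minimal sketch, assuming the components cover all interaction degrees the model was fit with (here max_interaction = 2); the variable names reconstructed and pred are illustrative only.

```r
library(randomPlantedForest)

# Refit the model from the example above so this block is self-contained
mtcars$cyl <- factor(mtcars$cyl)
rpfit <- rpf(mpg ~ cyl + wt + hp, data = mtcars, ntrees = 25, max_interaction = 2)

components <- predict_components(rpfit, new_data = mtcars)
pred <- predict(rpfit, new_data = mtcars)

# Intercept + sum of all main and interaction terms, one value per observation
reconstructed <- unname(components$intercept + rowSums(components$m))

# Compare with the regular predictions (up to numerical tolerance)
all.equal(reconstructed, pred$.pred)
```

The printed output above is consistent with this: for the first observation the six terms sum to roughly 0.57, which together with the intercept of about 20.2 gives the prediction of about 20.8.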

Various visualization options are available via glex, e.g. for main and second-order interaction effects:

# install glex if not available:
if (!requireNamespace("glex")) remotes::install_github("PlantedML/glex")
#> Loading required namespace: glex
library(glex)
library(ggplot2)
library(patchwork) # For plot arrangement

p1 <- autoplot(components, "wt")
p2 <- autoplot(components, "hp")
p3 <- autoplot(components, "cyl")
p4 <- autoplot(components, c("wt", "hp"))

(p1 + p2) / (p3 + p4) +
  plot_annotation(
    title = "Selected effects for mtcars",
    caption = "(It's a tiny dataset but it has to fit in a README, okay?)"
  )

See the Bikesharing decomposition article for more examples.